    def __getitem__(self, idx):
        sid = self.synset_ids[idx]
        mid = self.model_ids[idx]
        iid = self.image_ids[idx]

        # Always read metadata for this model; TODO cache in __init__?
        metadata_path = os.path.join(self.data_dir, sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RT = metadata["extrinsics"][iid]
        img_path = metadata["image_list"][iid]
        img_path = os.path.join(self.data_dir, sid, mid, "images", img_path)

        # Load the image
        with open(img_path, "rb") as f:
            img = Image.open(f).convert("RGB")
        img = self.transform(img)

        # Maybe read mesh
        verts, faces = None, None
        if self.return_mesh:
            mesh_path = os.path.join(self.data_dir, sid, mid, "mesh.pt")
            mesh_data = torch.load(mesh_path)
            verts, faces = mesh_data["verts"], mesh_data["faces"]
            verts = project_verts(verts, RT)

        # Maybe use cached samples
        points, normals = None, None
        if not self.sample_online:
            samples = self.mid_to_samples.get(mid, None)
            if samples is None:
                # They were not cached in memory, so read off disk
                samples_path = os.path.join(self.data_dir, sid, mid,
                                            "samples.pt")
                samples = torch.load(samples_path)
            points = samples["points_sampled"]
            normals = samples["normals_sampled"]
            idx = torch.randperm(points.shape[0])[:self.num_samples]
            points, normals = points[idx], normals[idx]
            points = project_verts(points, RT)
            normals = normals.mm(
                RT[:3, :3].t())  # Only rotate, don't translate

        voxels, P = None, None
        if self.voxel_size > 0:
            # Use precomputed voxels if we have them, otherwise return voxel_coords
            # and we will compute voxels in postprocess
            voxel_file = "vox%d/%03d.pt" % (self.voxel_size, iid)
            voxel_file = os.path.join(self.data_dir, sid, mid, voxel_file)
            if os.path.isfile(voxel_file):
                voxels = torch.load(voxel_file)
            else:
                voxel_path = os.path.join(self.data_dir, sid, mid, "voxels.pt")
                voxel_data = torch.load(voxel_path)
                voxels = voxel_data["voxel_coords"]
                P = K.mm(RT)

        id_str = "%s-%s-%02d" % (sid, mid, iid)
        return img, verts, faces, points, normals, voxels, P, id_str
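A note on project_verts: it is called throughout these examples but its definition is not included here. The sketch below is an assumed implementation, inferred only from how it is used (a 4x4 matrix applied to an (N, 3) tensor of points in homogeneous coordinates, followed by a perspective divide); it is illustrative rather than the actual library code.

import torch

def project_verts(verts, P, eps=1e-4):
    # Assumed behavior: verts is (N, 3), P is (4, 4).
    # Lift to homogeneous coordinates, apply P, then divide by w.
    N = verts.shape[0]
    ones = torch.ones(N, 1, dtype=verts.dtype, device=verts.device)
    verts_hom = torch.cat([verts, ones], dim=1)  # (N, 4)
    out = verts_hom.mm(P.t())                    # (N, 4)
    w = out[:, 3:].clamp(min=eps)                # guard against w close to 0
    return out[:, :3] / w

For a pure extrinsic RT (last row [0, 0, 0, 1]) the homogeneous w stays 1, so the divide is a no-op; for P = K.mm(RT) it performs the perspective divide.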
Example 2
def voxelize(voxel_coords, P, V):
    device = voxel_coords.device
    voxel_coords = project_verts(voxel_coords, P)

    # Using the actual zmin and zmax of the model is bad because we need them
    # to perform the inverse transform, which transforms voxels back into world
    # space for refinement or evaluation. Instead we use an empirical min and
    # max over the dataset; that way it is consistent for all images.
    zmin = SHAPENET_MIN_ZMIN
    zmax = SHAPENET_MAX_ZMAX

    # Once we know zmin and zmax, we need to adjust the z coordinates so that
    # the range [zmin, zmax] maps onto [-1, 1]
    m = 2.0 / (zmax - zmin)
    b = -2.0 * zmin / (zmax - zmin) - 1
    voxel_coords[:, 2].mul_(m).add_(b)
    voxel_coords[:, 1].mul_(-1)  # Flip y

    # Now voxels are in [-1, 1]^3; map to [0, V-1)^3
    voxel_coords = 0.5 * (V - 1) * (voxel_coords + 1.0)
    voxel_coords = voxel_coords.round().to(torch.int64)
    valid = (0 <= voxel_coords) * (voxel_coords < V)
    valid = valid[:, 0] * valid[:, 1] * valid[:, 2]
    x, y, z = voxel_coords.unbind(dim=1)
    x, y, z = x[valid], y[valid], z[valid]
    voxels = torch.zeros(V, V, V, dtype=torch.uint8, device=device)
    voxels[z, y, x] = 1

    return voxels
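The affine z-remapping above, and the inverse transform the comment refers to, can be sanity-checked in isolation. The bounds below are illustrative stand-ins; the real SHAPENET_MIN_ZMIN / SHAPENET_MAX_ZMAX constants are defined elsewhere in the codebase.

import torch

# Illustrative bounds only, not the dataset constants.
zmin, zmax = -1.5, -0.5
m = 2.0 / (zmax - zmin)
b = -2.0 * zmin / (zmax - zmin) - 1

z = torch.tensor([zmin, 0.5 * (zmin + zmax), zmax])
z_norm = m * z + b         # maps to tensor([-1., 0., 1.])
z_back = (z_norm - b) / m  # inverse transform, back to camera-space z
assert torch.allclose(z_back, z)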
Example 3
    def sample_points_normals(self, data_dir, sid, mid, RT):
        samples = self.mid_to_samples.get(mid, None)
        if samples is None:
            # They were not cached in memory, so read off disk
            samples_path = os.path.join(data_dir, sid, mid, "samples.pt")
            samples = torch.load(samples_path)
        points = samples["points_sampled"]
        normals = samples["normals_sampled"]
        idx = torch.randperm(points.shape[0])[: self.num_samples]
        points, normals = points[idx], normals[idx]
        points = project_verts(points, RT)
        normals = normals.mm(RT[:3, :3].t())  # Only rotate, don't translate
        return points, normals
    def forward(self, img_feats, meshes, vert_feats=None, P=None):
        """
        Args:
          img_feats (tensor): Features from the backbone
          meshes (Meshes): Initial meshes which will get refined
          vert_feats (tensor): Features from the previous refinement stage
          P (tensor): Tensor of shape (N, 4, 4) giving projection matrix to be applied
                      to vertex positions before vert-align. If None, don't project verts.
        """
        # Project verts if we are making predictions in world space
        verts_padded_to_packed_idx = meshes.verts_padded_to_packed_idx()

        if P is not None:
            vert_pos_padded = project_verts(meshes.verts_padded(), P)
            vert_pos_packed = _padded_to_packed(vert_pos_padded,
                                                verts_padded_to_packed_idx)
        else:
            vert_pos_padded = meshes.verts_padded()
            vert_pos_packed = meshes.verts_packed()

        # flip y coordinate
        device, dtype = vert_pos_padded.device, vert_pos_padded.dtype
        factor = torch.tensor([1, -1, 1], device=device,
                              dtype=dtype).view(1, 1, 3)
        vert_pos_padded = vert_pos_padded * factor
        # Get features from the image
        #         print(img_feats[0].shape)
        #         print(vert_pos_padded.shape)
        vert_align_feats = vert_align(img_feats, vert_pos_padded)
        vert_align_feats = _padded_to_packed(vert_align_feats,
                                             verts_padded_to_packed_idx)
        vert_align_feats = F.relu(self.bottleneck(vert_align_feats))

        # Prepare features for first graph conv layer
        first_layer_feats = [vert_align_feats, vert_pos_packed]
        if vert_feats is not None:
            first_layer_feats.append(vert_feats)
        vert_feats = torch.cat(first_layer_feats, dim=1)

        # Run graph conv layers
        for gconv in self.gconvs:
            vert_feats_nopos = F.relu(gconv(vert_feats, meshes.edges_packed()))
            vert_feats = torch.cat([vert_feats_nopos, vert_pos_packed], dim=1)

        # Predict a new mesh by offsetting verts
        vert_offsets = torch.tanh(self.vert_offset(vert_feats))
        meshes_out = meshes.offset_verts(vert_offsets)

        return meshes_out, vert_feats_nopos
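_padded_to_packed is not shown in these excerpts. A minimal sketch consistent with how it is called here, converting a padded (N, V, C) tensor into a packed (sum V_i, C) tensor using the index from Meshes.verts_padded_to_packed_idx(), could look like this; it is an assumption about the helper, not its verbatim source.

import torch

def _padded_to_packed(x, idx):
    # x:   (N, V, C) padded per-vertex features
    # idx: (sum_i V_i,) indices into the flattened (N*V, C) view, as
    #      returned by Meshes.verts_padded_to_packed_idx()
    N, V, C = x.shape
    x_flat = x.reshape(N * V, C)
    return x_flat.gather(0, idx.view(-1, 1).expand(-1, C))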
Example 5
    def _voxelize(self, voxel_coords, P):
        V = self.voxel_size
        device = voxel_coords.device
        voxel_coords = project_verts(voxel_coords, P)

        # In the original coordinate system, the object fits in a unit sphere
        # centered at the origin. Thus after transforming by RT, it will fit
        # in a unit sphere centered at T = RT[:, 3] = (0, 0, RT[2, 3]). We need
        # to figure out what the range of z will be after being further
        # transformed by K. We can work this out explicitly.
        # z0 = RT[2, 3].item()
        # zp, zm = z0 - 0.5, z0 + 0.5
        # k22, k23 = K[2, 2].item(), K[2, 3].item()
        # k32, k33 = K[3, 2].item(), K[3, 3].item()
        # zmin = (zm * k22 + k23) / (zm * k32 + k33)
        # zmax = (zp * k22 + k23) / (zp * k32 + k33)

        # Using the actual zmin and zmax of the model is bad because we need them
        # to perform the inverse transform, which transforms voxels back into world
        # space for refinement or evaluation. Instead we use an empirical min and
        # max over the dataset; that way it is consistent for all images.
        zmin = SHAPENET_MIN_ZMIN
        zmax = SHAPENET_MAX_ZMAX

        # Once we know zmin and zmax, we need to adjust the z coordinates so that
        # the range [zmin, zmax] maps onto [-1, 1]
        m = 2.0 / (zmax - zmin)
        b = -2.0 * zmin / (zmax - zmin) - 1
        voxel_coords[:, 2].mul_(m).add_(b)
        voxel_coords[:, 1].mul_(-1)  # Flip y

        # Now voxels are in [-1, 1]^3; map to [0, V-1)^3
        voxel_coords = 0.5 * (V - 1) * (voxel_coords + 1.0)
        voxel_coords = voxel_coords.round().to(torch.int64)
        valid = (0 <= voxel_coords) * (voxel_coords < V)
        valid = valid[:, 0] * valid[:, 1] * valid[:, 2]
        x, y, z = voxel_coords.unbind(dim=1)
        x, y, z = x[valid], y[valid], z[valid]
        voxels = torch.zeros(V, V, V, dtype=torch.int64, device=device)
        voxels[z, y, x] = 1

        return voxels
def training_loop(cfg, cp, model, optimizer, scheduler, loaders, device,
                  loss_fn):

    if comm.is_main_process():
        wandb.init(project='MeshRCNN', config=cfg, name='meshrcnn')

    Timer.timing = False
    iteration_timer = Timer("Iteration")

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    if hasattr(model, "module"):
        params = list(model.module.parameters())
    else:
        params = list(model.parameters())
    loss_moving_average = cp.data.get("loss_moving_average", None)

    # Zhengyuan modification
    loss_predictor = LossPredictionModule().to(device)
    loss_pred_optim = torch.optim.SGD(loss_predictor.parameters(),
                                      lr=1e-3,
                                      momentum=0.9)

    while cp.epoch < cfg.SOLVER.NUM_EPOCHS:
        if comm.is_main_process():
            logger.info("Starting epoch %d / %d" %
                        (cp.epoch + 1, cfg.SOLVER.NUM_EPOCHS))

        # When using a DistributedSampler we need to manually set the epoch so that
        # the data is shuffled differently at each epoch
        for loader in loaders.values():
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(cp.epoch)

        # Config settings for renderer
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=256,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).float().to(device)

        for i, batch in enumerate(loaders["train"]):
            if i == 0:
                iteration_timer.start()
            else:
                iteration_timer.tick()

            batch = loaders["train"].postprocess(batch, device)
            if dataset == 'MeshVoxMulti':
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt, id_strs, _imgs, render_RTs, RTs = batch
            else:
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch

            # NOTE: _imgs contains all of the other images belonging to this model
            # We have to select the next-best-view from that list of images

            num_infinite_params = 0
            for p in params:
                num_infinite_params += (torch.isfinite(
                    p.data) == 0).sum().item()
            if num_infinite_params > 0:
                msg = "ERROR: Model has %d non-finite params (before forward!)"
                logger.info(msg % num_infinite_params)
                return

            model_kwargs = {}
            if cfg.MODEL.VOXEL_ON and cp.t < cfg.MODEL.VOXEL_HEAD.VOXEL_ONLY_ITERS:
                model_kwargs["voxel_only"] = True
            with Timer("Forward"):
                voxel_scores, meshes_pred = model(imgs, **model_kwargs)

            num_infinite = 0
            for cur_meshes in meshes_pred:
                cur_verts = cur_meshes.verts_packed()
                num_infinite += (torch.isfinite(cur_verts) == 0).sum().item()
            if num_infinite > 0:
                logger.info("ERROR: Got %d non-finite verts" % num_infinite)
                return

            total_silh_loss = torch.tensor(
                0.)  # Total silhouette loss, to be added to "loss" below
            # Voxel only training for first few iterations
            if meshes_gt is not None and not model_kwargs.get(
                    "voxel_only", False):
                _meshes_pred = meshes_pred[-1].clone()
                _meshes_gt = meshes_gt[-1].clone()

                # Render masks from predicted mesh for each view
                # GT probability map to supervise prediction module
                B = len(meshes_gt)
                probability_map = 0.01 * torch.ones((B, 24))  # batch size x 24
                for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(
                        zip(meshes_gt, _meshes_pred)):
                    # Possibly computationally expensive, but we need to transform back to world space based on the rendered-image viewpoint
                    RT = RTs[b]
                    # Rotate 90 degrees about y-axis and invert
                    invRT = torch.inverse(RT.mm(rot_y_90))
                    invRT_no_rot = torch.inverse(RT)  # Just invert

                    cur_pred_mesh._verts_list[0] = project_verts(
                        cur_pred_mesh._verts_list[0], invRT)
                    sid = id_strs[b].split('-')[0]

                    # For some strange reason all classes (except the vehicle class) require a 90-degree rotation about the y-axis
                    if sid == '02958343':
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT_no_rot)
                    else:
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT)

                    for iid in range(len(render_RTs[b])):

                        R = render_RTs[b][iid][:3, :3].unsqueeze(0)
                        T = render_RTs[b][iid][:3, 3].unsqueeze(0)
                        cameras = OpenGLPerspectiveCameras(device=device,
                                                           R=R,
                                                           T=T)
                        silhouette_renderer = MeshRenderer(
                            rasterizer=MeshRasterizer(
                                cameras=cameras,
                                raster_settings=raster_settings),
                            shader=SoftSilhouetteShader(
                                blend_params=blend_params))

                        ref_image = silhouette_renderer(
                            meshes_world=cur_gt_mesh, R=R, T=T)
                        image = silhouette_renderer(meshes_world=cur_pred_mesh,
                                                    R=R,
                                                    T=T)
                        '''
                        import matplotlib.pyplot as plt
                        plt.subplot(1,2,1)
                        plt.imshow(ref_image[0,:,:,3].detach().cpu().numpy())
                        plt.subplot(1,2,2)
                        plt.imshow(image[0,:,:,3].detach().cpu().numpy())
                        plt.show()
                        '''

                        # MSE Loss between both silhouettes
                        silh_loss = torch.sum(
                            (image[0, :, :, 3] - ref_image[0, :, :, 3])**2)
                        probability_map[b, iid] = silh_loss.detach()

                        total_silh_loss += silh_loss

                probability_map = torch.nn.functional.softmax(
                    probability_map, dim=1)  # Softmax across images
                nbv_idx = torch.argmax(probability_map,
                                       dim=1)  # Next-best view indices
                nbv_imgs = _imgs[torch.arange(B),
                                 nbv_idx]  # Next-best view images

                # NOTE: Do a second forward pass through the model? This time for multi-view reconstruction
                # The input should be the first image and the next-best view
                #voxel_scores, meshes_pred = model(nbv_imgs, **model_kwargs)

            loss, losses = None, {}
            if num_infinite == 0:
                loss, losses = loss_fn(voxel_scores, meshes_pred, voxels_gt,
                                       (points_gt, normals_gt))
            skip = loss is None
            if loss is None or (torch.isfinite(loss) == 0).sum().item() > 0:
                logger.info("WARNING: Got non-finite loss %f" % loss)
                skip = True

            # Add silhouette loss to total loss
            silh_weight = 1.0  # TODO: Add a weight for the silhouette loss?
            loss = loss + total_silh_loss * silh_weight
            losses['silhouette'] = total_silh_loss

            if model_kwargs.get("voxel_only", False):
                for k, v in losses.items():
                    if k != "voxel":
                        losses[k] = 0.0 * v

            if loss is not None and cp.t % cfg.SOLVER.LOGGING_PERIOD == 0:
                if comm.is_main_process():
                    cp.store_metric(loss=loss.item())
                    str_out = "Iteration: %d, epoch: %d, lr: %.5f," % (
                        cp.t,
                        cp.epoch,
                        optimizer.param_groups[0]["lr"],
                    )
                    for k, v in losses.items():
                        str_out += "  %s loss: %.4f," % (k, v.item())
                    str_out += "  total loss: %.4f," % loss.item()

                    # memory allocated
                    if torch.cuda.is_available():
                        max_mem_mb = torch.cuda.max_memory_allocated(
                        ) / 1024.0 / 1024.0
                        str_out += " mem: %d" % max_mem_mb

                    if len(meshes_pred) > 0:
                        mean_V = meshes_pred[-1].num_verts_per_mesh().float(
                        ).mean().item()
                        mean_F = meshes_pred[-1].num_faces_per_mesh().float(
                        ).mean().item()
                        str_out += ", mesh size = (%d, %d)" % (mean_V, mean_F)
                    logger.info(str_out)

                # Log with Weights & Biases, comment out if not installed
                wandb.log(losses)

            if loss_moving_average is None and loss is not None:
                loss_moving_average = loss.item()

            # Skip backprop for this batch if the loss is above the skip factor times
            # the moving average for losses
            if loss is None:
                pass
            elif loss.item(
            ) > cfg.SOLVER.SKIP_LOSS_THRESH * loss_moving_average:
                logger.info("Warning: Skipping loss %f on GPU %d" %
                            (loss.item(), comm.get_rank()))
                cp.store_metric(losses_skipped=loss.item())
                skip = True
            else:
                # Update the moving average of our loss
                gamma = cfg.SOLVER.LOSS_SKIP_GAMMA
                loss_moving_average *= gamma
                loss_moving_average += (1.0 - gamma) * loss.item()
                cp.store_data("loss_moving_average", loss_moving_average)

            if skip:
                logger.info("Dummy backprop on GPU %d" % comm.get_rank())
                loss = 0.0 * sum(p.sum() for p in params)

            # Backprop and step
            scheduler.step()
            optimizer.zero_grad()
            with Timer("Backward"):
                loss.backward()

            # Zhengyuan step loss_prediction
            loss_predictor.train_batch(image, probability_map, loss_pred_optim)

            # When training with normal loss, sometimes I get NaNs in gradient that
            # cause the model to explode. Check for this before performing a gradient
            # update. This is safe in multi-GPU since gradients have already been
            # summed, so each GPU has the same gradients.
            num_infinite_grad = 0
            for p in params:
                num_infinite_grad += (torch.isfinite(p.grad) == 0).sum().item()
            if num_infinite_grad == 0:
                optimizer.step()
            else:
                msg = "WARNING: Got %d non-finite elements in gradient; skipping update"
                logger.info(msg % num_infinite_grad)
            cp.step()

            if cp.t % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
                eval_and_save(model, loaders, optimizer, scheduler, cp)
        cp.step_epoch()
    eval_and_save(model, loaders, optimizer, scheduler, cp)

    if comm.is_main_process():
        logger.info("Evaluating on test set:")
        test_loader = build_data_loader(cfg, dataset, "test", multigpu=False)
        evaluate_test(model, test_loader)
Example 7
def training_loop(cfg, cp, model, optimizer, scheduler, loaders, device,
                  loss_fn):

    #if comm.is_main_process():
    #    wandb.init(project='MeshRCNN', config=cfg, name='prediction_module')

    Timer.timing = False
    iteration_timer = Timer("Iteration")

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    if hasattr(model, "module"):
        params = list(model.module.parameters())
    else:
        params = list(model.parameters())
    loss_moving_average = cp.data.get("loss_moving_average", None)

    # Zhengyuan modification
    loss_predictor = LossPredictionModule().to(device)
    loss_pred_optim = torch.optim.Adam(loss_predictor.parameters(), lr=1e-5)

    while cp.epoch < cfg.SOLVER.NUM_EPOCHS:
        if comm.is_main_process():
            logger.info("Starting epoch %d / %d" %
                        (cp.epoch + 1, cfg.SOLVER.NUM_EPOCHS))

        # When using a DistributedSampler we need to manually set the epoch so that
        # the data is shuffled differently at each epoch
        for loader in loaders.values():
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(cp.epoch)

        # Config settings for renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).float().to(device)

        for i, batch in enumerate(loaders["train"]):
            if i == 0:
                iteration_timer.start()
            else:
                iteration_timer.tick()

            batch = loaders["train"].postprocess(batch, device)
            if dataset == 'MeshVoxMulti':
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt, id_strs, _, render_RTs, RTs = batch
            else:
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch

            with inference_context(model):
                # NOTE: _imgs contains all of the other images belonging to this model
                # We have to select the next-best-view from that list of images

                model_kwargs = {}
                if cfg.MODEL.VOXEL_ON and cp.t < cfg.MODEL.VOXEL_HEAD.VOXEL_ONLY_ITERS:
                    model_kwargs["voxel_only"] = True
                with Timer("Forward"):
                    voxel_scores, meshes_pred = model(imgs, **model_kwargs)

            total_silh_loss = torch.tensor(
                0.)  # Total silhouette loss, to be added to "loss" below
            # Voxel only training for first few iterations
            if meshes_gt is not None and not model_kwargs.get(
                    "voxel_only", False):
                _meshes_pred = meshes_pred[-1].clone()
                _meshes_gt = meshes_gt[-1].clone()

                # Render masks from predicted mesh for each view
                # GT probability map to supervise prediction module
                B = len(meshes_gt)
                probability_map = 0.01 * torch.ones(
                    (B, 24)).to(device)  # batch size x 24
                viewgrid = torch.zeros(
                    (B, 24, render_image_size,
                     render_image_size)).to(device)  # batch size x 24 x H x W
                for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(
                        zip(meshes_gt, _meshes_pred)):
                    # Possibly computationally expensive, but we need to transform back to world space based on the rendered-image viewpoint
                    RT = RTs[b]
                    # Rotate 90 degrees about y-axis and invert
                    invRT = torch.inverse(RT.mm(rot_y_90))
                    invRT_no_rot = torch.inverse(RT)  # Just invert

                    cur_pred_mesh._verts_list[0] = project_verts(
                        cur_pred_mesh._verts_list[0], invRT)
                    sid = id_strs[b].split('-')[0]

                    # For some strange reason all classes (except the vehicle class) require a 90-degree rotation about the y-axis
                    if sid == '02958343':
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT_no_rot)
                    else:
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT)

                    for iid in range(len(render_RTs[b])):

                        R = render_RTs[b][iid][:3, :3].unsqueeze(0)
                        T = render_RTs[b][iid][:3, 3].unsqueeze(0)
                        cameras = OpenGLPerspectiveCameras(device=device,
                                                           R=R,
                                                           T=T)
                        silhouette_renderer = MeshRenderer(
                            rasterizer=MeshRasterizer(
                                cameras=cameras,
                                raster_settings=raster_settings),
                            shader=SoftSilhouetteShader(
                                blend_params=blend_params))

                        ref_image = (silhouette_renderer(
                            meshes_world=cur_gt_mesh, R=R, T=T) > 0).float()
                        image = (silhouette_renderer(
                            meshes_world=cur_pred_mesh, R=R, T=T) > 0).float()

                        #Add image silhouette to viewgrid
                        viewgrid[b, iid] = image[..., -1]
                        '''
                        import matplotlib.pyplot as plt
                        plt.subplot(1,2,1)
                        plt.imshow(ref_image[0,:,:,3].detach().cpu().numpy())
                        plt.subplot(1,2,2)
                        plt.imshow(image[0,:,:,3].detach().cpu().numpy())
                        plt.show()
                        '''

                        # MSE Loss between both silhouettes
                        silh_loss = torch.sum(
                            (image[0, :, :, 3] - ref_image[0, :, :, 3])**2)
                        probability_map[b, iid] = silh_loss.detach()

                        total_silh_loss += silh_loss

                probability_map = probability_map / (torch.max(
                    probability_map, dim=1)[0].unsqueeze(1))  # Normalize

                probability_map = torch.nn.functional.softmax(
                    probability_map, dim=1).to(device)  # Softmax across images
                #nbv_idx = torch.argmax(probability_map, dim=1)  # Next-best view indices
                #nbv_imgs = _imgs[torch.arange(B), nbv_idx]  # Next-best view images

                # NOTE: Do a second forward pass through the model? This time for multi-view reconstruction
                # The input should be the first image and the next-best view
                #voxel_scores, meshes_pred = model(nbv_imgs, **model_kwargs)

                # Zhengyuan step loss_prediction
                predictor_loss = loss_predictor.train_batch(
                    viewgrid, probability_map, loss_pred_optim)
                if comm.is_main_process():
                    #wandb.log({'prediction module loss':predictor_loss})

                    if cp.t % 50 == 0:
                        print('{} predictor_loss: {}'.format(
                            cp.t, predictor_loss))

                    # Save a checkpoint every 500 iterations
                    if cp.t % 500 == 0:
                        print(
                            'Saving loss prediction module at iter {}'.format(
                                cp.t))
                        os.makedirs('./output_prediction_module',
                                    exist_ok=True)
                        torch.save(
                            loss_predictor.state_dict(),
                            './output_prediction_module/prediction_module_' +
                            str(cp.t) + '.pth')

            cp.step()

            if cp.t % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
                eval_and_save(model, loaders, optimizer, scheduler, cp)
        cp.step_epoch()
    eval_and_save(model, loaders, optimizer, scheduler, cp)

    if comm.is_main_process():
        logger.info("Evaluating on test set:")
        test_loader = build_data_loader(cfg, dataset, "test", multigpu=False)
        evaluate_test(model, test_loader)
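LossPredictionModule is instantiated and trained above but not defined in these excerpts. The sketch below is an assumed minimal version consistent with its usage: it maps a (B, 24, H, W) viewgrid of rendered silhouettes to a 24-way view-score distribution and exposes a train_batch helper. The architecture is purely illustrative, not the authors' model.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LossPredictionModule(nn.Module):
    def __init__(self, num_views=24):
        super().__init__()
        # Treat the 24 silhouettes as the input channels of a single image.
        self.features = nn.Sequential(
            nn.Conv2d(num_views, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Linear(64, num_views)

    def forward(self, viewgrid):
        # viewgrid: (B, 24, H, W) -> per-view scores (B, 24)
        x = self.features(viewgrid).flatten(1)
        return F.softmax(self.head(x), dim=1)

    def train_batch(self, viewgrid, target_prob_map, optimizer):
        optimizer.zero_grad()
        pred = self.forward(viewgrid)
        loss = F.mse_loss(pred, target_prob_map)
        loss.backward()
        optimizer.step()
        return loss.item()

In the loop above this would be driven by loss_predictor.train_batch(viewgrid, probability_map, loss_pred_optim), with probability_map detached from the rendering graph.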
Example 8
    def run_on_image(self, image, id_str, gt_verts, gt_faces):
        deprocess = imagenet_deprocess(rescale_image=False)

        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(image.to(self.device))

        sid, mid, iid = id_str.split('-')
        iid = int(iid)

        #Transform vertex space
        metadata_path = os.path.join('./datasets/shapenet/ShapeNetV1processed',
                                     sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RTs = metadata["extrinsics"].to(self.device)
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)

        mesh = meshes_pred[-1][0]
        # For some strange reason all classes (except the vehicle class) require a 90-degree rotation about the y-axis
        #for the GT mesh
        invRT = torch.inverse(RTs[iid].mm(rot_y_90))
        invRT_no_rot = torch.inverse(RTs[iid])
        mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT)

        #Get look at view extrinsics
        render_metadata_path = os.path.join(
            'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid,
            'rendering_metadata.pt')
        render_metadata = torch.load(render_metadata_path)
        render_RTs = render_metadata['extrinsics'].to(self.device)

        verts, faces = mesh.get_mesh_verts_faces(0)
        verts_rgb = torch.ones_like(verts)[None]
        textures = Textures(verts_rgb=verts_rgb.to(self.device))
        mesh.textures = textures

        plt.figure(figsize=(10, 10))

        #Silhouette Renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )

        gt_verts = gt_verts.to(self.device)
        gt_faces = gt_faces.to(self.device)
        verts_rgb = torch.ones_like(gt_verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        #Invert without the rotation for the vehicle class
        if sid == '02958343':
            gt_verts = project_verts(gt_verts, invRT_no_rot.to(self.device))
        else:
            gt_verts = project_verts(gt_verts, invRT.to(self.device))
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

        probability_map = 0.01 * torch.ones((1, 24))
        viewgrid = torch.zeros(
            (1, 24, render_image_size, render_image_size)).to(self.device)
        fig = plt.figure(1)
        ax_pred = [fig.add_subplot(5, 5, i + 1) for i in range(24)]
        #fig = plt.figure(2)
        #ax_gt = [fig.add_subplot(5,5,i+1) for i in range(24)]

        for i in range(len(render_RTs)):
            if i == iid:  #Don't include current view
                continue

            R = render_RTs[i][:3, :3].unsqueeze(0)
            T = render_RTs[i][:3, 3].unsqueeze(0)
            cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)

            silhouette_renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras,
                                          raster_settings=raster_settings),
                shader=SoftSilhouetteShader(blend_params=blend_params))

            ref_image = (silhouette_renderer(meshes_world=gt_mesh, R=R, T=T) >
                         0).float()
            silhouette_image = (silhouette_renderer(
                meshes_world=mesh, R=R, T=T) > 0).float()

            # MSE Loss between both silhouettes
            silh_loss = torch.sum(
                (silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3])**2)
            probability_map[0, i] = silh_loss.detach()

            viewgrid[0, i] = silhouette_image[..., -1]

            #ax_gt[i].imshow(ref_image[0,:,:,3].cpu().numpy())
            #ax_gt[i].set_title(i)

            ax_pred[i].imshow(silhouette_image[0, :, :, 3].cpu().numpy())
            ax_pred[i].set_title(i)

        img = image_to_numpy(deprocess(image[0]))
        #ax_gt[iid].imshow(img)
        ax_pred[iid].imshow(img)
        #fig = plt.figure(3)
        #ax = fig.add_subplot(111)
        #ax.imshow(img)

        pred_prob_map = self.loss_predictor(viewgrid)
        print('Highest actual loss: {}'.format(torch.argmax(probability_map)))
        print('Highest predicted loss: {}'.format(torch.argmax(pred_prob_map)))
        plt.show()
    def run_on_image(self, imgs, id_strs, meshes_gt, render_RTs, RTs):
        deprocess = imagenet_deprocess(rescale_image=False)

        #Silhouette Renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size, 
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
            faces_per_pixel=50, 
        )   

        rot_y_90 = torch.tensor([[0, 0, 1, 0], 
                                    [0, 1, 0, 0], 
                                    [-1, 0, 0, 0], 
                                    [0, 0, 0, 1]]).to(RTs) 
        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(imgs)

        B, _, H, W = imgs.shape
        probability_map = 0.01 * torch.ones((B, 24)).to(self.device)
        viewgrid = torch.zeros((B, 24, render_image_size,
                                render_image_size)).to(self.device)  # batch size x 24 x H x W
        _meshes_pred = meshes_pred[-1]

        for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(zip(meshes_gt, _meshes_pred)):
            sid = id_strs[b].split('-')[0]

            RT = RTs[b]

            # For some strange reason all classes (except the vehicle class) require a 90-degree rotation about the y-axis
            #for the GT mesh
            invRT = torch.inverse(RT.mm(rot_y_90))
            invRT_no_rot = torch.inverse(RT)

            cur_pred_mesh._verts_list[0] = project_verts(cur_pred_mesh._verts_list[0], invRT)
            #Invert without the rotation for the vehicle class
            if sid == '02958343':
                cur_gt_mesh._verts_list[0] = project_verts(
                    cur_gt_mesh._verts_list[0], invRT_no_rot)
            else:
                cur_gt_mesh._verts_list[0] = project_verts(
                    cur_gt_mesh._verts_list[0], invRT)

            '''
            plt.figure(figsize=(10, 10))

            fig = plt.figure(1)
            ax_pred = [fig.add_subplot(5,5,i+1) for i in range(24)]
            fig = plt.figure(2)
            ax_gt = [fig.add_subplot(5,5,i+1) for i in range(24)]
            '''

            for iid in range(len(render_RTs[b])):

                R = render_RTs[b][iid][:3,:3].unsqueeze(0)
                T = render_RTs[b][iid][:3,3].unsqueeze(0)
                cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)
                silhouette_renderer = MeshRenderer(
                    rasterizer=MeshRasterizer(
                    cameras=cameras, 
                    raster_settings=raster_settings
                    ),  
                shader=SoftSilhouetteShader(blend_params=blend_params)
                )   

                ref_image        = (silhouette_renderer(meshes_world=cur_gt_mesh, R=R, T=T)>0).float()
                silhouette_image = (silhouette_renderer(meshes_world=cur_pred_mesh, R=R, T=T)>0).float()

                #Add image silhouette to viewgrid
                viewgrid[b,iid] = silhouette_image[...,-1]

                # MSE Loss between both silhouettes
                silh_loss = torch.sum((silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3]) ** 2)
                probability_map[b, iid] = silh_loss.detach()

                '''
                ax_gt[iid].imshow(ref_image[0,:,:,3].cpu().numpy())
                ax_gt[iid].set_title(iid)
                ax_pred[iid].imshow(silhouette_image[0,:,:,3].cpu().numpy())
                ax_pred[iid].set_title(iid)
                '''

            '''
            img = image_to_numpy(deprocess(imgs[b]))
            ax_gt[iid].imshow(img)
            ax_pred[iid].imshow(img)
            fig = plt.figure(3)
            ax = fig.add_subplot(111)
            ax.imshow(img)

            #plt.show()
            '''

        probability_map = probability_map/(torch.max(probability_map, dim=1)[0].unsqueeze(1)) # Normalize
        probability_map = torch.nn.functional.softmax(probability_map, dim=1) # Softmax across images
        pred_prob_map = self.loss_predictor(viewgrid)

        gt_max = torch.argmax(probability_map, dim=1)
        pred_max = torch.argmax(pred_prob_map, dim=1)

        #print('--'*30)
        #print('Item: {}'.format(id_str))
        #print('Highest actual loss: {}'.format(gt_max))
        #print('Highest predicted loss: {}'.format(pred_max))

        #print('GT prob map: {}'.format(probability_map.squeeze()))
        #print('Pred prob map: {}'.format(pred_prob_map.squeeze()))

        correct = torch.sum(pred_max == gt_max).item()
        total   = len(pred_prob_map)

        return correct, total
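Because this run_on_image variant returns a (correct, total) pair, next-best-view top-1 accuracy can be accumulated over a data loader in the obvious way. The loader contents and the visualizer wrapper name below are assumptions for illustration.

correct_total, seen_total = 0, 0
for imgs, id_strs, meshes_gt, render_RTs, RTs in loader:
    correct, total = visualizer.run_on_image(imgs, id_strs, meshes_gt,
                                             render_RTs, RTs)
    correct_total += correct
    seen_total += total
print("next-best-view top-1 accuracy: %.3f" % (correct_total / max(seen_total, 1)))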
Example 10
    def __getitem__(self, idx):
        sid = self.synset_ids[idx]
        mid = self.model_ids[idx]
        iid = self.image_ids[idx]

        # Always read metadata for this model; TODO cache in __init__?
        metadata_path = os.path.join(self.data_dir, sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RT = metadata["extrinsics"][iid]
        img_path = metadata["image_list"][iid]
        img_path = os.path.join(self.data_dir, sid, mid, "images", img_path)

        # Load the image
        with open(img_path, "rb") as f:
            img = Image.open(f).convert("RGB")
        img = self.transform(img)

        # All other rendered images for this model. Needed for "viewgrid"
        _iids = [
            self.image_ids[i] for i in self.mid_to_idx[sid + '_' + mid]
            if i != idx
        ]
        _imgs = []
        for _iid in _iids:
            img_path = metadata["image_list"][_iid]
            img_path = os.path.join(self.data_dir, sid, mid, "images",
                                    img_path)

            with open(img_path, "rb") as f:
                _img = Image.open(f).convert("RGB")
            _imgs.append(self.transform(_img))
        _imgs = torch.stack(_imgs)  # N x C x H x W

        # Zero-pad so N = 23 (24 images minus the current one)
        N, C, H, W = _imgs.shape
        n = 23
        d = 0
        if len(_imgs) < n:
            d = n - len(_imgs)
            _imgs = torch.cat((_imgs, torch.zeros(d, C, H, W)))

        #Mask for zero-padded images. 1 for valid image, 0 for zero image
        mask = torch.cat((torch.ones(N), torch.zeros(d)))

        #Get transformation matrices to view renderings from same camera position
        render_metadata_path = os.path.join(self.rendering_dir, sid, mid,
                                            'rendering_metadata.pt')
        render_metadata = torch.load(render_metadata_path)
        render_RTs = render_metadata['extrinsics']

        #Only keep matrices for all "other" images
        idxs = torch.arange(n + 1)
        keep = (idxs != iid)
        render_RTs = render_RTs[keep]  # N x 3 x 4

        # Maybe read mesh
        verts, faces = None, None
        if self.return_mesh:
            mesh_path = os.path.join(self.data_dir, sid, mid, "mesh.pt")
            mesh_data = torch.load(mesh_path)
            verts, faces = mesh_data["verts"], mesh_data["faces"]
            verts = project_verts(verts, RT)

        # Maybe use cached samples
        points, normals = None, None
        if not self.sample_online:
            samples = self.mid_to_samples.get(mid, None)
            if samples is None:
                # They were not cached in memory, so read off disk
                samples_path = os.path.join(self.data_dir, sid, mid,
                                            "samples.pt")
                samples = torch.load(samples_path)
            points = samples["points_sampled"]
            normals = samples["normals_sampled"]
            idx = torch.randperm(points.shape[0])[:self.num_samples]
            points, normals = points[idx], normals[idx]
            points = project_verts(points, RT)
            normals = normals.mm(
                RT[:3, :3].t())  # Only rotate, don't translate

        voxels, P = None, None
        if self.voxel_size > 0:
            # Use precomputed voxels if we have them, otherwise return voxel_coords
            # and we will compute voxels in postprocess
            voxel_file = "vox%d/%03d.pt" % (self.voxel_size, iid)
            voxel_file = os.path.join(self.data_dir, sid, mid, voxel_file)
            if os.path.isfile(voxel_file):
                voxels = torch.load(voxel_file)
            else:
                voxel_path = os.path.join(self.data_dir, sid, mid, "voxels.pt")
                voxel_data = torch.load(voxel_path)
                voxels = voxel_data["voxel_coords"]
                P = K.mm(RT)

        id_str = "%s-%s-%02d" % (sid, mid, iid)
        return img, verts, faces, points, normals, voxels, P, id_str, _imgs, render_RTs, RT
Example 11
def read_mesh(data_dir, sid, mid, RT):
    mesh_path = os.path.join(data_dir, sid, mid, "mesh.pt")
    mesh_data = torch.load(mesh_path)
    verts, faces = mesh_data["verts"], mesh_data["faces"]
    verts = project_verts(verts, RT)
    return verts, faces
    def run_on_image(self, image, id_str, gt_verts, gt_faces):
        deprocess = imagenet_deprocess(rescale_image=False)

        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(image)

        sid, mid, iid = id_str.split('-')
        iid = int(iid)

        #Transform vertex space
        metadata_path = os.path.join('./datasets/shapenet/ShapeNetV1processed',
                                     sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RTs = metadata["extrinsics"]
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)

        mesh = meshes_pred[-1][0]
        # For some strange reason all classes (except the vehicle class) require a 90-degree rotation about the y-axis
        #for the GT mesh
        invRT = torch.inverse(RTs[iid].mm(rot_y_90))
        invRT_no_rot = torch.inverse(RTs[iid])
        mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT.cpu())

        #Get look at view extrinsics
        render_metadata_path = os.path.join(
            'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid,
            'rendering_metadata.pt')
        render_metadata = torch.load(render_metadata_path)
        render_RTs = render_metadata['extrinsics']

        plt.figure(figsize=(10, 10))
        R = render_RTs[iid][:3, :3].unsqueeze(0)
        T = render_RTs[iid][:3, 3].unsqueeze(0)
        cameras = OpenGLPerspectiveCameras(R=R, T=T)

        #Phong Renderer
        lights = PointLights(location=[[0.0, 0.0, -3.0]])
        raster_settings = RasterizationSettings(image_size=256,
                                                blur_radius=0.0,
                                                faces_per_pixel=1,
                                                bin_size=0)
        phong_renderer = MeshRenderer(rasterizer=MeshRasterizer(
            cameras=cameras, raster_settings=raster_settings),
                                      shader=HardPhongShader(lights=lights))

        #Silhouette Renderer
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=256,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        silhouette_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras,
                                      raster_settings=raster_settings),
            shader=SoftSilhouetteShader(blend_params=blend_params))

        verts, faces = mesh.get_mesh_verts_faces(0)
        verts_rgb = torch.ones_like(verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        mesh.textures = textures

        verts_rgb = torch.ones_like(gt_verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        #Invert without the rotation for the vehicle class
        if sid == '02958343':
            gt_verts = project_verts(gt_verts, invRT_no_rot.cpu())
        else:
            gt_verts = project_verts(gt_verts, invRT.cpu())
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

        img = image_to_numpy(deprocess(image[0]))
        mesh_image = phong_renderer(meshes_world=mesh, R=R, T=T)
        gt_silh_image = (silhouette_renderer(meshes_world=gt_mesh, R=R, T=T) >
                         0).float()
        silhouette_image = (silhouette_renderer(meshes_world=mesh, R=R, T=T) >
                            0).float()

        plt.subplot(2, 2, 1)
        plt.imshow(img)
        plt.title('input image')
        plt.subplot(2, 2, 2)
        plt.imshow(mesh_image[0, ..., :3].cpu().numpy())
        plt.title('rendered mesh')
        plt.subplot(2, 2, 3)
        plt.imshow(gt_silh_image[0, ..., 3].cpu().numpy())
        plt.title('silhouette of gt mesh')
        plt.subplot(2, 2, 4)
        plt.imshow(silhouette_image[0, ..., 3].cpu().numpy())
        plt.title('silhouette of rendered mesh')

        plt.show()
        #plt.savefig('./output_demo/figures/'+id_str+'.png')

        vis_utils.visualize_prediction(id_str, img, mesh, self.output_dir)
verts, faces_idx, _ = load_obj(obj_filename)
faces = faces_idx.verts_idx
'''

##########Predicted Mesh
#Generate this by running: python demo/demo_modified.py --config-file configs/shapenet/voxmesh_R50.yaml --data_dir datasets/shapenet/ShapeNetV1processed --output output_demo --checkpoint shapenet://voxmesh_R50.pth --index 0
obj_filename = './output_demo/results_shapenet/02958343-4856ef1e80d356d111f983eb293b51a-00.obj'

verts, faces_idx, _ = load_obj(obj_filename)
faces = faces_idx.verts_idx
invRT = torch.inverse(RTs[0].mm(rot_y_90))
#invRT = torch.inverse(RTs[0].mm(rot_x_n90)) 
#invRT = torch.inverse(RTs[0]) 
verts = project_verts(verts, invRT.cpu())

##########GT Mesh 
'''
mesh_path = os.path.join('./datasets/shapenet/ShapeNetV1processed', sid, mid, "mesh.pt")
mesh_data = torch.load(mesh_path)
verts, faces = mesh_data["verts"], mesh_data["faces"]
verts = project_verts(verts, RTs[0].cpu())
'''


verts_rgb = torch.ones_like(verts)[None] 
textures = Textures(verts_rgb=verts_rgb.to(device))
# print(verts_rgb.shape, verts.shape)
mesh = Meshes(
    verts=[verts.to(device)],   
Example 14
    def __getitem__(self, idx):
        sid = self.synset_ids[idx]
        mid = self.model_ids[idx]
        iid = self.image_ids[idx]

        # Always read metadata for this model; TODO cache in __init__?
        metadata_path = os.path.join(self.data_dir, sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RT = metadata["extrinsics"][iid]
        img_path = metadata["image_list"][iid]
        img_path = os.path.join(self.data_dir, sid, mid, "images", img_path)

        # Load the image
        with open(img_path, "rb") as f:
            img = Image.open(f).convert("RGB")
        img = self.transform(img)

        # All images for this model
        # CAN add a variable to control the number of images
        _iids = [self.image_ids[i] for i in self.mid_to_idx[mid]]
        _imgs = []
        for _iid in _iids:
            img_path = metadata["image_list"][_iid]
            img_path = os.path.join(self.data_dir, sid, mid, "images",
                                    img_path)

            with open(img_path, "rb") as f:
                _img = Image.open(f).convert("RGB")
            _imgs.append(self.transform(_img))
        _imgs = torch.stack(_imgs)
        # tensor([N,3,224,224])
        ##
        #         N,C,H,W = _imgs.shape
        #         if N != 24:
        #             n = 24-N
        #             padding = torch.zeros((n,C,H,W),dtype = _imgs.dtype, device = _imgs.device)
        #             _imgs = torch.cat((_imgs, padding))
        ##

        # Maybe read mesh
        verts, faces = None, None
        if self.return_mesh:
            mesh_path = os.path.join(self.data_dir, sid, mid, "mesh.pt")
            mesh_data = torch.load(mesh_path)
            verts, faces = mesh_data["verts"], mesh_data["faces"]
            verts = project_verts(verts, RT)

        # Maybe use cached samples
        points, normals = None, None
        if not self.sample_online:
            samples = self.mid_to_samples.get(mid, None)
            if samples is None:
                # They were not cached in memory, so read off disk
                samples_path = os.path.join(self.data_dir, sid, mid,
                                            "samples.pt")
                samples = torch.load(samples_path)
            points = samples["points_sampled"]
            normals = samples["normals_sampled"]
            idx = torch.randperm(points.shape[0])[:self.num_samples]
            points, normals = points[idx], normals[idx]
            points = project_verts(points, RT)
            normals = normals.mm(
                RT[:3, :3].t())  # Only rotate, don't translate

        voxels, P = None, None
        if self.voxel_size > 0:
            # Use precomputed voxels if we have them, otherwise return voxel_coords
            # and we will compute voxels in postprocess
            voxel_file = "vox%d/%03d.pt" % (self.voxel_size, iid)
            voxel_file = os.path.join(self.data_dir, sid, mid, voxel_file)
            if os.path.isfile(voxel_file):
                voxels = torch.load(voxel_file)
            else:
                voxel_path = os.path.join(self.data_dir, sid, mid, "voxels.pt")
                voxel_data = torch.load(voxel_path)
                voxels = voxel_data["voxel_coords"]
                P = K.mm(RT)

        id_str = "%s-%s-%02d" % (sid, mid, iid)
        return img, verts, faces, points, normals, voxels, P, id_str, _imgs