def __getitem__(self, idx):
    sid = self.synset_ids[idx]
    mid = self.model_ids[idx]
    iid = self.image_ids[idx]

    # Always read metadata for this model; TODO cache in __init__?
    metadata_path = os.path.join(self.data_dir, sid, mid, "metadata.pt")
    metadata = torch.load(metadata_path)
    K = metadata["intrinsic"]
    RT = metadata["extrinsics"][iid]
    img_path = metadata["image_list"][iid]
    img_path = os.path.join(self.data_dir, sid, mid, "images", img_path)

    # Load the image
    with open(img_path, "rb") as f:
        img = Image.open(f).convert("RGB")
    img = self.transform(img)

    # Maybe read mesh
    verts, faces = None, None
    if self.return_mesh:
        mesh_path = os.path.join(self.data_dir, sid, mid, "mesh.pt")
        mesh_data = torch.load(mesh_path)
        verts, faces = mesh_data["verts"], mesh_data["faces"]
        verts = project_verts(verts, RT)

    # Maybe use cached samples
    points, normals = None, None
    if not self.sample_online:
        samples = self.mid_to_samples.get(mid, None)
        if samples is None:
            # They were not cached in memory, so read off disk
            samples_path = os.path.join(self.data_dir, sid, mid, "samples.pt")
            samples = torch.load(samples_path)
        points = samples["points_sampled"]
        normals = samples["normals_sampled"]
        idx = torch.randperm(points.shape[0])[: self.num_samples]
        points, normals = points[idx], normals[idx]
        points = project_verts(points, RT)
        normals = normals.mm(RT[:3, :3].t())  # Only rotate, don't translate

    voxels, P = None, None
    if self.voxel_size > 0:
        # Use precomputed voxels if we have them, otherwise return voxel_coords
        # and we will compute voxels in postprocess
        voxel_file = "vox%d/%03d.pt" % (self.voxel_size, iid)
        voxel_file = os.path.join(self.data_dir, sid, mid, voxel_file)
        if os.path.isfile(voxel_file):
            voxels = torch.load(voxel_file)
        else:
            voxel_path = os.path.join(self.data_dir, sid, mid, "voxels.pt")
            voxel_data = torch.load(voxel_path)
            voxels = voxel_data["voxel_coords"]
        P = K.mm(RT)

    id_str = "%s-%s-%02d" % (sid, mid, iid)
    return img, verts, faces, points, normals, voxels, P, id_str
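
# The "TODO cache in __init__?" above could be handled by loading each model's
# metadata.pt once up front. A minimal sketch, not the repo's implementation;
# it assumes all metadata files fit in memory and reuses the fields read in
# __getitem__ above. __getitem__ would then do
# `metadata = self.mid_to_metadata[(sid, mid)]` instead of torch.load.
def _cache_metadata(self):
    self.mid_to_metadata = {}
    for sid, mid in zip(self.synset_ids, self.model_ids):
        if (sid, mid) not in self.mid_to_metadata:
            metadata_path = os.path.join(self.data_dir, sid, mid, "metadata.pt")
            self.mid_to_metadata[(sid, mid)] = torch.load(metadata_path)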
def voxelize(voxel_coords, P, V):
    device = voxel_coords.device
    voxel_coords = project_verts(voxel_coords, P)

    # Using the actual zmin and zmax of the model is bad because we need them
    # to perform the inverse transform, which transforms voxels back into world
    # space for refinement or evaluation. Instead we use an empirical min and
    # max over the dataset; that way it is consistent for all images.
    zmin = SHAPENET_MIN_ZMIN
    zmax = SHAPENET_MAX_ZMAX

    # Once we know zmin and zmax, we need to adjust the z coordinates so the
    # range [zmin, zmax] instead runs from -1 to 1
    m = 2.0 / (zmax - zmin)
    b = -2.0 * zmin / (zmax - zmin) - 1
    voxel_coords[:, 2].mul_(m).add_(b)
    voxel_coords[:, 1].mul_(-1)  # Flip y

    # Now voxels are in [-1, 1]^3; map to [0, V-1]^3
    voxel_coords = 0.5 * (V - 1) * (voxel_coords + 1.0)
    voxel_coords = voxel_coords.round().to(torch.int64)
    valid = (0 <= voxel_coords) * (voxel_coords < V)
    valid = valid[:, 0] * valid[:, 1] * valid[:, 2]
    x, y, z = voxel_coords.unbind(dim=1)
    x, y, z = x[valid], y[valid], z[valid]
    voxels = torch.zeros(V, V, V, dtype=torch.uint8, device=device)
    voxels[z, y, x] = 1
    return voxels
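
# Sanity check for the z normalization above (illustrative, standalone): the
# affine map with m = 2 / (zmax - zmin) and b = -2 * zmin / (zmax - zmin) - 1
# sends zmin to -1 and zmax to +1. The constants below are placeholders, not
# the real SHAPENET_MIN_ZMIN / SHAPENET_MAX_ZMAX.
def _check_z_normalization(zmin=0.4, zmax=3.9):
    m = 2.0 / (zmax - zmin)
    b = -2.0 * zmin / (zmax - zmin) - 1
    assert abs(m * zmin + b + 1.0) < 1e-6  # zmin maps to -1
    assert abs(m * zmax + b - 1.0) < 1e-6  # zmax maps to +1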
def sample_points_normals(self, data_dir, sid, mid, RT):
    samples = self.mid_to_samples.get(mid, None)
    if samples is None:
        # They were not cached in memory, so read off disk
        samples_path = os.path.join(data_dir, sid, mid, "samples.pt")
        samples = torch.load(samples_path)
    points = samples["points_sampled"]
    normals = samples["normals_sampled"]
    idx = torch.randperm(points.shape[0])[: self.num_samples]
    points, normals = points[idx], normals[idx]
    points = project_verts(points, RT)
    normals = normals.mm(RT[:3, :3].t())  # Only rotate, don't translate
    return points, normals
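
# Why normals are only rotated: for a rigid transform x -> R x + t, direction
# vectors transform by the inverse-transpose of the linear part, which for a
# rotation is R itself, and the translation does not apply. An illustrative,
# standalone check with a rotation about the z-axis:
def _check_normal_rotation():
    import math
    a = math.radians(30.0)
    R = torch.tensor([[math.cos(a), -math.sin(a), 0.0],
                      [math.sin(a),  math.cos(a), 0.0],
                      [0.0,          0.0,         1.0]])
    n = torch.tensor([0.0, 0.0, 1.0])  # normal of the xy-plane
    t = torch.tensor([1.0, 2.0, 0.0])  # a tangent direction in that plane
    # Rotating both preserves orthogonality: (R n) . (R t) = n . t = 0
    assert torch.allclose(R.mv(n).dot(R.mv(t)), torch.tensor(0.0), atol=1e-6)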
def forward(self, img_feats, meshes, vert_feats=None, P=None):
    """
    Args:
      img_feats (tensor): Features from the backbone
      meshes (Meshes): Initial meshes which will get refined
      vert_feats (tensor): Features from the previous refinement stage
      P (tensor): Tensor of shape (N, 4, 4) giving projection matrix to be
        applied to vertex positions before vert-align. If None, don't project verts.
    """
    # Project verts if we are making predictions in world space
    verts_padded_to_packed_idx = meshes.verts_padded_to_packed_idx()
    if P is not None:
        vert_pos_padded = project_verts(meshes.verts_padded(), P)
        vert_pos_packed = _padded_to_packed(vert_pos_padded, verts_padded_to_packed_idx)
    else:
        vert_pos_padded = meshes.verts_padded()
        vert_pos_packed = meshes.verts_packed()

    # Flip y coordinate
    device, dtype = vert_pos_padded.device, vert_pos_padded.dtype
    factor = torch.tensor([1, -1, 1], device=device, dtype=dtype).view(1, 1, 3)
    vert_pos_padded = vert_pos_padded * factor

    # Get features from the image
    vert_align_feats = vert_align(img_feats, vert_pos_padded)
    vert_align_feats = _padded_to_packed(vert_align_feats, verts_padded_to_packed_idx)
    vert_align_feats = F.relu(self.bottleneck(vert_align_feats))

    # Prepare features for first graph conv layer
    first_layer_feats = [vert_align_feats, vert_pos_packed]
    if vert_feats is not None:
        first_layer_feats.append(vert_feats)
    vert_feats = torch.cat(first_layer_feats, dim=1)

    # Run graph conv layers
    for gconv in self.gconvs:
        vert_feats_nopos = F.relu(gconv(vert_feats, meshes.edges_packed()))
        vert_feats = torch.cat([vert_feats_nopos, vert_pos_packed], dim=1)

    # Predict a new mesh by offsetting verts
    vert_offsets = torch.tanh(self.vert_offset(vert_feats))
    meshes_out = meshes.offset_verts(vert_offsets)
    return meshes_out, vert_feats_nopos
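
# _padded_to_packed is defined elsewhere in the repo; a minimal sketch that is
# consistent with how it is called above: it gathers rows of a padded
# (N, V_max, D) tensor into a packed (sum(V_i), D) tensor using the index
# returned by meshes.verts_padded_to_packed_idx().
def _padded_to_packed(x, idx):
    D = x.shape[-1]
    idx = idx.view(-1, 1).expand(-1, D)
    x_packed = x.view(-1, D).gather(0, idx)
    return x_packed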
def _voxelize(self, voxel_coords, P):
    V = self.voxel_size
    device = voxel_coords.device
    voxel_coords = project_verts(voxel_coords, P)

    # In the original coordinate system, the object fits in a unit sphere
    # centered at the origin. Thus after transforming by RT, it will fit
    # in a unit sphere centered at T = RT[:, 3] = (0, 0, RT[2, 3]). We need
    # to figure out what the range of z will be after being further
    # transformed by K. We can work this out explicitly.
    # z0 = RT[2, 3].item()
    # zp, zm = z0 - 0.5, z0 + 0.5
    # k22, k23 = K[2, 2].item(), K[2, 3].item()
    # k32, k33 = K[3, 2].item(), K[3, 3].item()
    # zmin = (zm * k22 + k23) / (zm * k32 + k33)
    # zmax = (zp * k22 + k23) / (zp * k32 + k33)

    # Using the actual zmin and zmax of the model is bad because we need them
    # to perform the inverse transform, which transforms voxels back into world
    # space for refinement or evaluation. Instead we use an empirical min and
    # max over the dataset; that way it is consistent for all images.
    zmin = SHAPENET_MIN_ZMIN
    zmax = SHAPENET_MAX_ZMAX

    # Once we know zmin and zmax, we need to adjust the z coordinates so the
    # range [zmin, zmax] instead runs from -1 to 1
    m = 2.0 / (zmax - zmin)
    b = -2.0 * zmin / (zmax - zmin) - 1
    voxel_coords[:, 2].mul_(m).add_(b)
    voxel_coords[:, 1].mul_(-1)  # Flip y

    # Now voxels are in [-1, 1]^3; map to [0, V-1]^3
    voxel_coords = 0.5 * (V - 1) * (voxel_coords + 1.0)
    voxel_coords = voxel_coords.round().to(torch.int64)
    valid = (0 <= voxel_coords) * (voxel_coords < V)
    valid = valid[:, 0] * valid[:, 1] * valid[:, 2]
    x, y, z = voxel_coords.unbind(dim=1)
    x, y, z = x[valid], y[valid], z[valid]
    voxels = torch.zeros(V, V, V, dtype=torch.int64, device=device)
    voxels[z, y, x] = 1
    return voxels
def training_loop(cfg, cp, model, optimizer, scheduler, loaders, device, loss_fn):
    if comm.is_main_process():
        wandb.init(project='MeshRCNN', config=cfg, name='meshrcnn')
    Timer.timing = False
    iteration_timer = Timer("Iteration")

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    if hasattr(model, "module"):
        params = list(model.module.parameters())
    else:
        params = list(model.parameters())
    loss_moving_average = cp.data.get("loss_moving_average", None)

    # Zhengyuan modification
    loss_predictor = LossPredictionModule().to(device)
    loss_pred_optim = torch.optim.SGD(loss_predictor.parameters(), lr=1e-3, momentum=0.9)

    while cp.epoch < cfg.SOLVER.NUM_EPOCHS:
        if comm.is_main_process():
            logger.info("Starting epoch %d / %d" % (cp.epoch + 1, cfg.SOLVER.NUM_EPOCHS))

        # When using a DistributedSampler we need to manually set the epoch so that
        # the data is shuffled differently at each epoch
        for loader in loaders.values():
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(cp.epoch)

        # Config settings for renderer
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=256,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        rot_y_90 = torch.tensor([[0, 0, 1, 0],
                                 [0, 1, 0, 0],
                                 [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).float().to(device)

        for i, batch in enumerate(loaders["train"]):
            if i == 0:
                iteration_timer.start()
            else:
                iteration_timer.tick()

            batch = loaders["train"].postprocess(batch, device)
            # NOTE: `dataset` is assumed to be defined at module scope
            if dataset == 'MeshVoxMulti':
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt, id_strs, _imgs, render_RTs, RTs = batch
            else:
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch

            # NOTE: _imgs contains all of the other images belonging to this model.
            # We have to select the next-best-view from that list of images.

            num_infinite_params = 0
            for p in params:
                num_infinite_params += (torch.isfinite(p.data) == 0).sum().item()
            if num_infinite_params > 0:
                msg = "ERROR: Model has %d non-finite params (before forward!)"
                logger.info(msg % num_infinite_params)
                return

            model_kwargs = {}
            if cfg.MODEL.VOXEL_ON and cp.t < cfg.MODEL.VOXEL_HEAD.VOXEL_ONLY_ITERS:
                model_kwargs["voxel_only"] = True
            with Timer("Forward"):
                voxel_scores, meshes_pred = model(imgs, **model_kwargs)

            num_infinite = 0
            for cur_meshes in meshes_pred:
                cur_verts = cur_meshes.verts_packed()
                num_infinite += (torch.isfinite(cur_verts) == 0).sum().item()
            if num_infinite > 0:
                logger.info("ERROR: Got %d non-finite verts" % num_infinite)
                return

            total_silh_loss = torch.tensor(0.)  # Total silhouette loss, to be added to "loss" below
            # Skip the silhouette computation during voxel-only training
            if meshes_gt is not None and not model_kwargs.get("voxel_only", False):
                _meshes_pred = meshes_pred[-1].clone()
                _meshes_gt = meshes_gt[-1].clone()

                # Render masks from the predicted mesh for each view;
                # GT probability map to supervise the prediction module
                B = len(meshes_gt)
                probability_map = 0.01 * torch.ones((B, 24))  # batch size x 24
                for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(zip(meshes_gt, _meshes_pred)):
                    # Maybe computationally expensive, but we need to transform back
                    # to world space based on the rendered image viewpoint
                    RT = RTs[b]

                    # Rotate 90 degrees about y-axis and invert
                    invRT = torch.inverse(RT.mm(rot_y_90))
                    invRT_no_rot = torch.inverse(RT)  # Just invert
                    cur_pred_mesh._verts_list[0] = project_verts(cur_pred_mesh._verts_list[0], invRT)
                    sid = id_strs[b].split('-')[0]

                    # For some strange reason all classes (except the vehicle class)
                    # require a 90 degree rotation about the y-axis
                    if sid == '02958343':
                        cur_gt_mesh._verts_list[0] = project_verts(cur_gt_mesh._verts_list[0], invRT_no_rot)
                    else:
                        cur_gt_mesh._verts_list[0] = project_verts(cur_gt_mesh._verts_list[0], invRT)

                    for iid in range(len(render_RTs[b])):
                        R = render_RTs[b][iid][:3, :3].unsqueeze(0)
                        T = render_RTs[b][iid][:3, 3].unsqueeze(0)
                        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
                        silhouette_renderer = MeshRenderer(
                            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
                            shader=SoftSilhouetteShader(blend_params=blend_params))

                        ref_image = silhouette_renderer(meshes_world=cur_gt_mesh, R=R, T=T)
                        image = silhouette_renderer(meshes_world=cur_pred_mesh, R=R, T=T)

                        # MSE loss between both silhouettes
                        silh_loss = torch.sum((image[0, :, :, 3] - ref_image[0, :, :, 3]) ** 2)
                        probability_map[b, iid] = silh_loss.detach()
                        total_silh_loss += silh_loss

                probability_map = torch.nn.functional.softmax(probability_map, dim=1)  # Softmax across images
                nbv_idx = torch.argmax(probability_map, dim=1)   # Next-best view indices
                nbv_imgs = _imgs[torch.arange(B), nbv_idx]       # Next-best view images

                # NOTE: Do a second forward pass through the model? This time for
                # multi-view reconstruction. The input should be the first image
                # and the next-best view.
                # voxel_scores, meshes_pred = model(nbv_imgs, **model_kwargs)

            loss, losses = None, {}
            if num_infinite == 0:
                loss, losses = loss_fn(voxel_scores, meshes_pred, voxels_gt, (points_gt, normals_gt))

            skip = loss is None
            if loss is None or (torch.isfinite(loss) == 0).sum().item() > 0:
                logger.info("WARNING: Got non-finite loss %f" % loss)
                skip = True

            # Add silhouette loss to total loss
            silh_weight = 1.0  # TODO: Add a weight for the silhouette loss?
            # NOTE: assumes loss is not None at this point
            loss = loss + total_silh_loss * silh_weight
            losses['silhouette'] = total_silh_loss

            if model_kwargs.get("voxel_only", False):
                for k, v in losses.items():
                    if k != "voxel":
                        losses[k] = 0.0 * v

            if loss is not None and cp.t % cfg.SOLVER.LOGGING_PERIOD == 0:
                if comm.is_main_process():
                    cp.store_metric(loss=loss.item())
                    str_out = "Iteration: %d, epoch: %d, lr: %.5f," % (
                        cp.t,
                        cp.epoch,
                        optimizer.param_groups[0]["lr"],
                    )
                    for k, v in losses.items():
                        str_out += " %s loss: %.4f," % (k, v.item())
                    str_out += " total loss: %.4f," % loss.item()

                    # Memory allocated
                    if torch.cuda.is_available():
                        max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
                        str_out += " mem: %d" % max_mem_mb

                    if len(meshes_pred) > 0:
                        mean_V = meshes_pred[-1].num_verts_per_mesh().float().mean().item()
                        mean_F = meshes_pred[-1].num_faces_per_mesh().float().mean().item()
                        str_out += ", mesh size = (%d, %d)" % (mean_V, mean_F)
                    logger.info(str_out)

                    # Log with Weights & Biases; comment out if not installed
                    wandb.log(losses)

            if loss_moving_average is None and loss is not None:
                loss_moving_average = loss.item()

            # Skip backprop for this batch if the loss is above the skip factor times
            # the moving average for losses
            if loss is None:
                pass
            elif loss.item() > cfg.SOLVER.SKIP_LOSS_THRESH * loss_moving_average:
                logger.info("Warning: Skipping loss %f on GPU %d" % (loss.item(), comm.get_rank()))
                cp.store_metric(losses_skipped=loss.item())
                skip = True
            else:
                # Update the moving average of our loss
                gamma = cfg.SOLVER.LOSS_SKIP_GAMMA
                loss_moving_average *= gamma
                loss_moving_average += (1.0 - gamma) * loss.item()
                cp.store_data("loss_moving_average", loss_moving_average)

            if skip:
                logger.info("Dummy backprop on GPU %d" % comm.get_rank())
                loss = 0.0 * sum(p.sum() for p in params)

            # Backprop and step
            scheduler.step()
            optimizer.zero_grad()
            with Timer("Backward"):
                loss.backward()

            # Zhengyuan: step loss prediction
            # NOTE: `image` here is the last silhouette rendered in the loop above
            loss_predictor.train_batch(image, probability_map, loss_pred_optim)

            # When training with normal loss, sometimes I get NaNs in gradient that
            # cause the model to explode. Check for this before performing a gradient
            # update. This is safe in multi-GPU since gradients have already been
            # summed, so each GPU has the same gradients.
            num_infinite_grad = 0
            for p in params:
                num_infinite_grad += (torch.isfinite(p.grad) == 0).sum().item()
            if num_infinite_grad == 0:
                optimizer.step()
            else:
                msg = "WARNING: Got %d non-finite elements in gradient; skipping update"
                logger.info(msg % num_infinite_grad)
            cp.step()

            if cp.t % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
                eval_and_save(model, loaders, optimizer, scheduler, cp)
        cp.step_epoch()
    eval_and_save(model, loaders, optimizer, scheduler, cp)

    if comm.is_main_process():
        logger.info("Evaluating on test set:")
        test_loader = build_data_loader(cfg, dataset, "test", multigpu=False)
        evaluate_test(model, test_loader)
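
# LossPredictionModule is defined elsewhere; a minimal sketch consistent with
# the calling convention used here and in the loops below, where
# train_batch(inputs, target_map, optim) takes a (B, 24, H, W) viewgrid of
# silhouettes plus a (B, 24) target probability map and returns a scalar loss.
# The conv/pool architecture is invented for illustration only.
class LossPredictionModule(torch.nn.Module):
    def __init__(self, num_views=24):
        super().__init__()
        self.conv = torch.nn.Conv2d(num_views, num_views, kernel_size=3, padding=1)
        self.pool = torch.nn.AdaptiveAvgPool2d(1)

    def forward(self, viewgrid):
        # viewgrid: (B, 24, H, W) silhouettes -> (B, 24) per-view scores
        x = self.pool(torch.relu(self.conv(viewgrid))).flatten(1)
        return torch.nn.functional.softmax(x, dim=1)

    def train_batch(self, viewgrid, target_map, optim):
        optim.zero_grad()
        loss = torch.nn.functional.mse_loss(self.forward(viewgrid), target_map)
        loss.backward()
        optim.step()
        return loss.item()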
def training_loop(cfg, cp, model, optimizer, scheduler, loaders, device, loss_fn):
    # if comm.is_main_process():
    #     wandb.init(project='MeshRCNN', config=cfg, name='prediction_module')
    Timer.timing = False
    iteration_timer = Timer("Iteration")

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    if hasattr(model, "module"):
        params = list(model.module.parameters())
    else:
        params = list(model.parameters())
    loss_moving_average = cp.data.get("loss_moving_average", None)

    # Zhengyuan modification
    loss_predictor = LossPredictionModule().to(device)
    loss_pred_optim = torch.optim.Adam(loss_predictor.parameters(), lr=1e-5)

    while cp.epoch < cfg.SOLVER.NUM_EPOCHS:
        if comm.is_main_process():
            logger.info("Starting epoch %d / %d" % (cp.epoch + 1, cfg.SOLVER.NUM_EPOCHS))

        # When using a DistributedSampler we need to manually set the epoch so that
        # the data is shuffled differently at each epoch
        for loader in loaders.values():
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(cp.epoch)

        # Config settings for renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        rot_y_90 = torch.tensor([[0, 0, 1, 0],
                                 [0, 1, 0, 0],
                                 [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).float().to(device)

        for i, batch in enumerate(loaders["train"]):
            if i == 0:
                iteration_timer.start()
            else:
                iteration_timer.tick()

            batch = loaders["train"].postprocess(batch, device)
            if dataset == 'MeshVoxMulti':
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt, id_strs, _, render_RTs, RTs = batch
            else:
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch

            with inference_context(model):
                # NOTE: _imgs contains all of the other images belonging to this model.
                # We have to select the next-best-view from that list of images.
                model_kwargs = {}
                if cfg.MODEL.VOXEL_ON and cp.t < cfg.MODEL.VOXEL_HEAD.VOXEL_ONLY_ITERS:
                    model_kwargs["voxel_only"] = True
                with Timer("Forward"):
                    voxel_scores, meshes_pred = model(imgs, **model_kwargs)

            total_silh_loss = torch.tensor(0.)  # Accumulated silhouette loss
            # Skip the silhouette computation during voxel-only training
            if meshes_gt is not None and not model_kwargs.get("voxel_only", False):
                _meshes_pred = meshes_pred[-1].clone()
                _meshes_gt = meshes_gt[-1].clone()

                # Render masks from the predicted mesh for each view;
                # GT probability map to supervise the prediction module
                B = len(meshes_gt)
                probability_map = 0.01 * torch.ones((B, 24)).to(device)  # batch size x 24
                viewgrid = torch.zeros(
                    (B, 24, render_image_size, render_image_size)).to(device)  # batch size x 24 x H x W

                for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(zip(meshes_gt, _meshes_pred)):
                    # Maybe computationally expensive, but we need to transform back
                    # to world space based on the rendered image viewpoint
                    RT = RTs[b]

                    # Rotate 90 degrees about y-axis and invert
                    invRT = torch.inverse(RT.mm(rot_y_90))
                    invRT_no_rot = torch.inverse(RT)  # Just invert
                    cur_pred_mesh._verts_list[0] = project_verts(cur_pred_mesh._verts_list[0], invRT)
                    sid = id_strs[b].split('-')[0]

                    # For some strange reason all classes (except the vehicle class)
                    # require a 90 degree rotation about the y-axis
                    if sid == '02958343':
                        cur_gt_mesh._verts_list[0] = project_verts(cur_gt_mesh._verts_list[0], invRT_no_rot)
                    else:
                        cur_gt_mesh._verts_list[0] = project_verts(cur_gt_mesh._verts_list[0], invRT)

                    for iid in range(len(render_RTs[b])):
                        R = render_RTs[b][iid][:3, :3].unsqueeze(0)
                        T = render_RTs[b][iid][:3, 3].unsqueeze(0)
                        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
                        silhouette_renderer = MeshRenderer(
                            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
                            shader=SoftSilhouetteShader(blend_params=blend_params))

                        ref_image = (silhouette_renderer(meshes_world=cur_gt_mesh, R=R, T=T) > 0).float()
                        image = (silhouette_renderer(meshes_world=cur_pred_mesh, R=R, T=T) > 0).float()

                        # Add image silhouette to viewgrid
                        viewgrid[b, iid] = image[..., -1]

                        # MSE loss between both silhouettes
                        silh_loss = torch.sum((image[0, :, :, 3] - ref_image[0, :, :, 3]) ** 2)
                        probability_map[b, iid] = silh_loss.detach()
                        total_silh_loss += silh_loss

                probability_map = probability_map / (torch.max(probability_map, dim=1)[0].unsqueeze(1))  # Normalize
                probability_map = torch.nn.functional.softmax(probability_map, dim=1).to(device)  # Softmax across images
                # nbv_idx = torch.argmax(probability_map, dim=1)   # Next-best view indices
                # nbv_imgs = _imgs[torch.arange(B), nbv_idx]       # Next-best view images
                # NOTE: Do a second forward pass through the model? This time for
                # multi-view reconstruction. The input should be the first image
                # and the next-best view.
                # voxel_scores, meshes_pred = model(nbv_imgs, **model_kwargs)

                # Zhengyuan: step loss prediction
                predictor_loss = loss_predictor.train_batch(viewgrid, probability_map, loss_pred_optim)

                if comm.is_main_process():
                    # wandb.log({'prediction module loss': predictor_loss})
                    if cp.t % 50 == 0:
                        print('{} predictor_loss: {}'.format(cp.t, predictor_loss))

                    # Save a checkpoint every 500 iterations
                    if cp.t % 500 == 0:
                        print('Saving loss prediction module at iter {}'.format(cp.t))
                        os.makedirs('./output_prediction_module', exist_ok=True)
                        torch.save(
                            loss_predictor.state_dict(),
                            './output_prediction_module/prediction_module_' + str(cp.t) + '.pth')

            cp.step()

            if cp.t % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
                eval_and_save(model, loaders, optimizer, scheduler, cp)
        cp.step_epoch()
    eval_and_save(model, loaders, optimizer, scheduler, cp)

    if comm.is_main_process():
        logger.info("Evaluating on test set:")
        test_loader = build_data_loader(cfg, dataset, "test", multigpu=False)
        evaluate_test(model, test_loader)
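
# To reuse a checkpoint saved by the loop above (illustrative sketch; the
# iteration number in the filename is a placeholder):
def load_prediction_module(t=500, device="cpu"):
    loss_predictor = LossPredictionModule()
    state = torch.load('./output_prediction_module/prediction_module_%d.pth' % t,
                       map_location=device)
    loss_predictor.load_state_dict(state)
    loss_predictor.eval()
    return loss_predictor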
def run_on_image(self, image, id_str, gt_verts, gt_faces):
    deprocess = imagenet_deprocess(rescale_image=False)
    with torch.no_grad():
        voxel_scores, meshes_pred = self.predictor(image.to(self.device))

    sid, mid, iid = id_str.split('-')
    iid = int(iid)

    # Transform vertex space
    metadata_path = os.path.join('./datasets/shapenet/ShapeNetV1processed', sid, mid, "metadata.pt")
    metadata = torch.load(metadata_path)
    K = metadata["intrinsic"]
    RTs = metadata["extrinsics"].to(self.device)
    rot_y_90 = torch.tensor([[0, 0, 1, 0],
                             [0, 1, 0, 0],
                             [-1, 0, 0, 0],
                             [0, 0, 0, 1]]).to(RTs)

    mesh = meshes_pred[-1][0]

    # For some strange reason all classes (except the vehicle class) require
    # a 90 degree rotation about the y-axis for the GT mesh
    invRT = torch.inverse(RTs[iid].mm(rot_y_90))
    invRT_no_rot = torch.inverse(RTs[iid])
    mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT)

    # Get look-at-view extrinsics
    render_metadata_path = os.path.join(
        'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid, 'rendering_metadata.pt')
    render_metadata = torch.load(render_metadata_path)
    render_RTs = render_metadata['extrinsics'].to(self.device)

    verts, faces = mesh.get_mesh_verts_faces(0)
    verts_rgb = torch.ones_like(verts)[None]
    textures = Textures(verts_rgb=verts_rgb.to(self.device))
    mesh.textures = textures

    plt.figure(figsize=(10, 10))

    # Silhouette renderer
    render_image_size = 256
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
    raster_settings = RasterizationSettings(
        image_size=render_image_size,
        blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
        faces_per_pixel=50,
    )

    gt_verts = gt_verts.to(self.device)
    gt_faces = gt_faces.to(self.device)
    verts_rgb = torch.ones_like(gt_verts)[None]
    textures = Textures(verts_rgb=verts_rgb)

    # Invert without the rotation for the vehicle class
    if sid == '02958343':
        gt_verts = project_verts(gt_verts, invRT_no_rot.to(self.device))
    else:
        gt_verts = project_verts(gt_verts, invRT.to(self.device))
    gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

    probability_map = 0.01 * torch.ones((1, 24))
    viewgrid = torch.zeros((1, 24, render_image_size, render_image_size)).to(self.device)

    fig = plt.figure(1)
    ax_pred = [fig.add_subplot(5, 5, i + 1) for i in range(24)]

    for i in range(len(render_RTs)):
        if i == iid:  # Don't include the current view
            continue
        R = render_RTs[i][:3, :3].unsqueeze(0)
        T = render_RTs[i][:3, 3].unsqueeze(0)
        cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)
        silhouette_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=SoftSilhouetteShader(blend_params=blend_params))

        ref_image = (silhouette_renderer(meshes_world=gt_mesh, R=R, T=T) > 0).float()
        silhouette_image = (silhouette_renderer(meshes_world=mesh, R=R, T=T) > 0).float()

        # MSE loss between both silhouettes
        silh_loss = torch.sum((silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3]) ** 2)
        probability_map[0, i] = silh_loss.detach()
        viewgrid[0, i] = silhouette_image[..., -1]

        ax_pred[i].imshow(silhouette_image[0, :, :, 3].cpu().numpy())
        ax_pred[i].set_title(i)

    img = image_to_numpy(deprocess(image[0]))
    ax_pred[iid].imshow(img)

    pred_prob_map = self.loss_predictor(viewgrid)
    print('Highest actual loss: {}'.format(torch.argmax(probability_map)))
    print('Highest predicted loss: {}'.format(torch.argmax(pred_prob_map)))
    plt.show()
def run_on_image(self, imgs, id_strs, meshes_gt, render_RTs, RTs):
    deprocess = imagenet_deprocess(rescale_image=False)

    # Silhouette renderer
    render_image_size = 256
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
    raster_settings = RasterizationSettings(
        image_size=render_image_size,
        blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
        faces_per_pixel=50,
    )
    rot_y_90 = torch.tensor([[0, 0, 1, 0],
                             [0, 1, 0, 0],
                             [-1, 0, 0, 0],
                             [0, 0, 0, 1]]).to(RTs)

    with torch.no_grad():
        voxel_scores, meshes_pred = self.predictor(imgs)

    B, _, H, W = imgs.shape
    probability_map = 0.01 * torch.ones((B, 24)).to(self.device)
    viewgrid = torch.zeros(
        (B, 24, render_image_size, render_image_size)).to(self.device)  # batch size x 24 x H x W
    _meshes_pred = meshes_pred[-1]

    for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(zip(meshes_gt, _meshes_pred)):
        sid = id_strs[b].split('-')[0]
        RT = RTs[b]

        # For some strange reason all classes (except the vehicle class) require
        # a 90 degree rotation about the y-axis for the GT mesh
        invRT = torch.inverse(RT.mm(rot_y_90))
        invRT_no_rot = torch.inverse(RT)
        cur_pred_mesh._verts_list[0] = project_verts(cur_pred_mesh._verts_list[0], invRT)

        # Invert without the rotation for the vehicle class
        if sid == '02958343':
            cur_gt_mesh._verts_list[0] = project_verts(cur_gt_mesh._verts_list[0], invRT_no_rot)
        else:
            cur_gt_mesh._verts_list[0] = project_verts(cur_gt_mesh._verts_list[0], invRT)

        for iid in range(len(render_RTs[b])):
            R = render_RTs[b][iid][:3, :3].unsqueeze(0)
            T = render_RTs[b][iid][:3, 3].unsqueeze(0)
            cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)
            silhouette_renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
                shader=SoftSilhouetteShader(blend_params=blend_params))

            ref_image = (silhouette_renderer(meshes_world=cur_gt_mesh, R=R, T=T) > 0).float()
            silhouette_image = (silhouette_renderer(meshes_world=cur_pred_mesh, R=R, T=T) > 0).float()

            # Add image silhouette to viewgrid
            viewgrid[b, iid] = silhouette_image[..., -1]

            # MSE loss between both silhouettes
            silh_loss = torch.sum((silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3]) ** 2)
            probability_map[b, iid] = silh_loss.detach()

    probability_map = probability_map / (torch.max(probability_map, dim=1)[0].unsqueeze(1))  # Normalize
    probability_map = torch.nn.functional.softmax(probability_map, dim=1)  # Softmax across images
    pred_prob_map = self.loss_predictor(viewgrid)
    gt_max = torch.argmax(probability_map, dim=1)
    pred_max = torch.argmax(pred_prob_map, dim=1)
    correct = torch.sum(pred_max == gt_max).item()
    total = len(pred_prob_map)
    return correct, total
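
# The per-view silhouette-renderer construction above is repeated in several
# places; a small helper (suggested refactor, not in the original code) that
# builds it once per camera using the same settings:
def make_silhouette_renderer(device, R, T, image_size=256):
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
        faces_per_pixel=50,
    )
    cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
    return MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=SoftSilhouetteShader(blend_params=blend_params))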
def __getitem__(self, idx):
    sid = self.synset_ids[idx]
    mid = self.model_ids[idx]
    iid = self.image_ids[idx]

    # Always read metadata for this model; TODO cache in __init__?
    metadata_path = os.path.join(self.data_dir, sid, mid, "metadata.pt")
    metadata = torch.load(metadata_path)
    K = metadata["intrinsic"]
    RT = metadata["extrinsics"][iid]
    img_path = metadata["image_list"][iid]
    img_path = os.path.join(self.data_dir, sid, mid, "images", img_path)

    # Load the image
    with open(img_path, "rb") as f:
        img = Image.open(f).convert("RGB")
    img = self.transform(img)

    # All other rendered images for this model; needed for the "viewgrid"
    _iids = [self.image_ids[i] for i in self.mid_to_idx[sid + '_' + mid] if i != idx]
    _imgs = []
    for _iid in _iids:
        img_path = metadata["image_list"][_iid]
        img_path = os.path.join(self.data_dir, sid, mid, "images", img_path)
        with open(img_path, "rb") as f:
            _img = Image.open(f).convert("RGB")
        _imgs.append(self.transform(_img))
    _imgs = torch.stack(_imgs)  # N x C x H x W

    # Zero-pad so N = 23 (24 images minus the current image)
    N, C, H, W = _imgs.shape
    n = 23
    d = 0
    if len(_imgs) < n:
        d = n - len(_imgs)
        _imgs = torch.cat((_imgs, torch.zeros(d, C, H, W)))

    # Mask for zero-padded images: 1 for a valid image, 0 for a zero image
    mask = torch.cat((torch.ones(N), torch.zeros(d)))

    # Get transformation matrices to view renderings from the same camera position
    render_metadata_path = os.path.join(self.rendering_dir, sid, mid, 'rendering_metadata.pt')
    render_metadata = torch.load(render_metadata_path)
    render_RTs = render_metadata['extrinsics']

    # Only keep matrices for all "other" images
    idxs = torch.arange(n + 1)
    keep = (idxs != iid)
    render_RTs = render_RTs[keep]  # N x 3 x 4

    # Maybe read mesh
    verts, faces = None, None
    if self.return_mesh:
        mesh_path = os.path.join(self.data_dir, sid, mid, "mesh.pt")
        mesh_data = torch.load(mesh_path)
        verts, faces = mesh_data["verts"], mesh_data["faces"]
        verts = project_verts(verts, RT)

    # Maybe use cached samples
    points, normals = None, None
    if not self.sample_online:
        samples = self.mid_to_samples.get(mid, None)
        if samples is None:
            # They were not cached in memory, so read off disk
            samples_path = os.path.join(self.data_dir, sid, mid, "samples.pt")
            samples = torch.load(samples_path)
        points = samples["points_sampled"]
        normals = samples["normals_sampled"]
        idx = torch.randperm(points.shape[0])[: self.num_samples]
        points, normals = points[idx], normals[idx]
        points = project_verts(points, RT)
        normals = normals.mm(RT[:3, :3].t())  # Only rotate, don't translate

    voxels, P = None, None
    if self.voxel_size > 0:
        # Use precomputed voxels if we have them, otherwise return voxel_coords
        # and we will compute voxels in postprocess
        voxel_file = "vox%d/%03d.pt" % (self.voxel_size, iid)
        voxel_file = os.path.join(self.data_dir, sid, mid, voxel_file)
        if os.path.isfile(voxel_file):
            voxels = torch.load(voxel_file)
        else:
            voxel_path = os.path.join(self.data_dir, sid, mid, "voxels.pt")
            voxel_data = torch.load(voxel_path)
            voxels = voxel_data["voxel_coords"]
        P = K.mm(RT)

    id_str = "%s-%s-%02d" % (sid, mid, iid)
    # NOTE: `mask` is computed above but not currently returned
    return img, verts, faces, points, normals, voxels, P, id_str, _imgs, render_RTs, RT
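
# Illustrative sketch: if the padding mask built above were returned alongside
# _imgs, a consumer could drop the zero-padded views like this (hypothetical
# helper, not part of the dataset API):
def _drop_padded_views(_imgs, mask):
    return _imgs[mask.bool()]  # keeps only the valid (non-zero) images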
def read_mesh(data_dir, sid, mid, RT):
    mesh_path = os.path.join(data_dir, sid, mid, "mesh.pt")
    mesh_data = torch.load(mesh_path)
    verts, faces = mesh_data["verts"], mesh_data["faces"]
    verts = project_verts(verts, RT)
    return verts, faces
def run_on_image(self, image, id_str, gt_verts, gt_faces):
    deprocess = imagenet_deprocess(rescale_image=False)
    with torch.no_grad():
        voxel_scores, meshes_pred = self.predictor(image)

    sid, mid, iid = id_str.split('-')
    iid = int(iid)

    # Transform vertex space
    metadata_path = os.path.join('./datasets/shapenet/ShapeNetV1processed', sid, mid, "metadata.pt")
    metadata = torch.load(metadata_path)
    K = metadata["intrinsic"]
    RTs = metadata["extrinsics"]
    rot_y_90 = torch.tensor([[0, 0, 1, 0],
                             [0, 1, 0, 0],
                             [-1, 0, 0, 0],
                             [0, 0, 0, 1]]).to(RTs)

    mesh = meshes_pred[-1][0]

    # For some strange reason all classes (except the vehicle class) require
    # a 90 degree rotation about the y-axis for the GT mesh
    invRT = torch.inverse(RTs[iid].mm(rot_y_90))
    invRT_no_rot = torch.inverse(RTs[iid])
    mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT.cpu())

    # Get look-at-view extrinsics
    render_metadata_path = os.path.join(
        'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid, 'rendering_metadata.pt')
    render_metadata = torch.load(render_metadata_path)
    render_RTs = render_metadata['extrinsics']

    plt.figure(figsize=(10, 10))

    R = render_RTs[iid][:3, :3].unsqueeze(0)
    T = render_RTs[iid][:3, 3].unsqueeze(0)
    cameras = OpenGLPerspectiveCameras(R=R, T=T)

    # Phong renderer
    lights = PointLights(location=[[0.0, 0.0, -3.0]])
    raster_settings = RasterizationSettings(
        image_size=256, blur_radius=0.0, faces_per_pixel=1, bin_size=0)
    phong_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=HardPhongShader(lights=lights))

    # Silhouette renderer
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
    raster_settings = RasterizationSettings(
        image_size=256,
        blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
        faces_per_pixel=50,
    )
    silhouette_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=SoftSilhouetteShader(blend_params=blend_params))

    verts, faces = mesh.get_mesh_verts_faces(0)
    verts_rgb = torch.ones_like(verts)[None]
    textures = Textures(verts_rgb=verts_rgb)
    mesh.textures = textures

    verts_rgb = torch.ones_like(gt_verts)[None]
    textures = Textures(verts_rgb=verts_rgb)

    # Invert without the rotation for the vehicle class
    if sid == '02958343':
        gt_verts = project_verts(gt_verts, invRT_no_rot.cpu())
    else:
        gt_verts = project_verts(gt_verts, invRT.cpu())
    gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

    img = image_to_numpy(deprocess(image[0]))
    mesh_image = phong_renderer(meshes_world=mesh, R=R, T=T)
    gt_silh_image = (silhouette_renderer(meshes_world=gt_mesh, R=R, T=T) > 0).float()
    silhouette_image = (silhouette_renderer(meshes_world=mesh, R=R, T=T) > 0).float()

    plt.subplot(2, 2, 1)
    plt.imshow(img)
    plt.title('input image')
    plt.subplot(2, 2, 2)
    plt.imshow(mesh_image[0, ..., :3].cpu().numpy())
    plt.title('rendered mesh')
    plt.subplot(2, 2, 3)
    plt.imshow(gt_silh_image[0, ..., 3].cpu().numpy())
    plt.title('silhouette of gt mesh')
    plt.subplot(2, 2, 4)
    plt.imshow(silhouette_image[0, ..., 3].cpu().numpy())
    plt.title('silhouette of rendered mesh')
    plt.show()
    # plt.savefig('./output_demo/figures/' + id_str + '.png')

    vis_utils.visualize_prediction(id_str, img, mesh, self.output_dir)
verts, faces_idx, _ = load_obj(obj_filename)
faces = faces_idx.verts_idx

########## Predicted Mesh
# Generate this by running:
#   python demo/demo_modified.py --config-file configs/shapenet/voxmesh_R50.yaml \
#     --data_dir datasets/shapenet/ShapeNetV1processed --output output_demo \
#     --checkpoint shapenet://voxmesh_R50.pth --index 0
obj_filename = './output_demo/results_shapenet/02958343-4856ef1e80d356d111f983eb293b51a-00.obj'
verts, faces_idx, _ = load_obj(obj_filename)
faces = faces_idx.verts_idx
invRT = torch.inverse(RTs[0].mm(rot_y_90))
# invRT = torch.inverse(RTs[0].mm(rot_x_n90))
# invRT = torch.inverse(RTs[0])
verts = project_verts(verts, invRT.cpu())

########## GT Mesh (commented out; swap in to render the ground-truth mesh instead)
'''
mesh_path = os.path.join('./datasets/shapenet/ShapeNetV1processed', sid, mid, "mesh.pt")
mesh_data = torch.load(mesh_path)
verts, faces = mesh_data["verts"], mesh_data["faces"]
verts = project_verts(verts, RTs[0].cpu())
'''

verts_rgb = torch.ones_like(verts)[None]
textures = Textures(verts_rgb=verts_rgb.to(device))
mesh = Meshes(
    verts=[verts.to(device)],
    faces=[faces.to(device)],
    textures=textures)
def __getitem__(self, idx):
    sid = self.synset_ids[idx]
    mid = self.model_ids[idx]
    iid = self.image_ids[idx]

    # Always read metadata for this model; TODO cache in __init__?
    metadata_path = os.path.join(self.data_dir, sid, mid, "metadata.pt")
    metadata = torch.load(metadata_path)
    K = metadata["intrinsic"]
    RT = metadata["extrinsics"][iid]
    img_path = metadata["image_list"][iid]
    img_path = os.path.join(self.data_dir, sid, mid, "images", img_path)

    # Load the image
    with open(img_path, "rb") as f:
        img = Image.open(f).convert("RGB")
    img = self.transform(img)

    # All images for this model. Could add a variable to control the number of images.
    _iids = [self.image_ids[i] for i in self.mid_to_idx[mid]]
    _imgs = []
    for _iid in _iids:
        img_path = metadata["image_list"][_iid]
        img_path = os.path.join(self.data_dir, sid, mid, "images", img_path)
        with open(img_path, "rb") as f:
            _img = Image.open(f).convert("RGB")
        _imgs.append(self.transform(_img))
    _imgs = torch.stack(_imgs)  # (N, 3, 224, 224)

    # Optionally zero-pad to a fixed N = 24 images:
    # N, C, H, W = _imgs.shape
    # if N != 24:
    #     n = 24 - N
    #     padding = torch.zeros((n, C, H, W), dtype=_imgs.dtype, device=_imgs.device)
    #     _imgs = torch.cat((_imgs, padding))

    # Maybe read mesh
    verts, faces = None, None
    if self.return_mesh:
        mesh_path = os.path.join(self.data_dir, sid, mid, "mesh.pt")
        mesh_data = torch.load(mesh_path)
        verts, faces = mesh_data["verts"], mesh_data["faces"]
        verts = project_verts(verts, RT)

    # Maybe use cached samples
    points, normals = None, None
    if not self.sample_online:
        samples = self.mid_to_samples.get(mid, None)
        if samples is None:
            # They were not cached in memory, so read off disk
            samples_path = os.path.join(self.data_dir, sid, mid, "samples.pt")
            samples = torch.load(samples_path)
        points = samples["points_sampled"]
        normals = samples["normals_sampled"]
        idx = torch.randperm(points.shape[0])[: self.num_samples]
        points, normals = points[idx], normals[idx]
        points = project_verts(points, RT)
        normals = normals.mm(RT[:3, :3].t())  # Only rotate, don't translate

    voxels, P = None, None
    if self.voxel_size > 0:
        # Use precomputed voxels if we have them, otherwise return voxel_coords
        # and we will compute voxels in postprocess
        voxel_file = "vox%d/%03d.pt" % (self.voxel_size, iid)
        voxel_file = os.path.join(self.data_dir, sid, mid, voxel_file)
        if os.path.isfile(voxel_file):
            voxels = torch.load(voxel_file)
        else:
            voxel_path = os.path.join(self.data_dir, sid, mid, "voxels.pt")
            voxel_data = torch.load(voxel_path)
            voxels = voxel_data["voxel_coords"]
        P = K.mm(RT)

    id_str = "%s-%s-%02d" % (sid, mid, iid)
    return img, verts, faces, points, normals, voxels, P, id_str, _imgs
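
# Because items mix fixed-size tensors (img) with variable-size data (verts,
# faces, ...), the default DataLoader collate cannot batch them directly. A
# minimal sketch of a list-based collate (the repo's own loaders use their own
# collation plus the postprocess(batch, device) step seen in the training
# loops above):
def simple_collate(batch):
    imgs = torch.stack([item[0] for item in batch])  # images are all (3, H, W)
    rest = [list(field) for field in list(zip(*batch))[1:]]  # keep other fields as lists
    return (imgs, *rest)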