def generate_mesh(self, data, fix_normals=False):
    ''' Generates a mesh.

    Arguments:
        data (tensor): input data
        fix_normals (boolean): if normals should be fixed
    '''
    images = data.get('inputs').to(self.device)
    cam = common.get_camera_args(
        data, 'pointcloud.loc', 'pointcloud.scale', device=self.device)
    world_mat, camera_mat = cam['Rt'], cam['K']

    with torch.no_grad():
        outputs1, outputs2 = self.model(images, camera_mat)
    # Only the final refinement stage's vertices are meshed.
    stage1, stage2, stage3 = outputs1

    # Predictions live in camera coordinates; map them back to world space.
    restored = common.transform_points_back(stage3, world_mat)
    vertices = restored.squeeze().cpu().numpy()

    # Drop the leading 'f' column of the face list and shift the indices
    # from 1-based (OBJ convention) to 0-based (trimesh convention).
    faces = self.base_mesh[:, 1:].astype(int) - 1

    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
    if fix_normals:
        # Fix normals due to wrong base ellipsoid
        trimesh.repair.fix_normals(mesh)
    return mesh
def compute_loss(self, data):
    ''' Computes the loss.

    Args:
        data (dict): data dictionary
    '''
    device = self.device
    p = data.get('points').to(device)
    # Optionally binarize soft occupancy labels at 0.5.
    if self.binary_occ:
        occ = (data.get('points.occ') >= 0.5).float().to(device)
    else:
        occ = data.get('points.occ').to(device)
    inputs = data.get('inputs', torch.empty(p.size(0), 0)).to(device)

    kwargs = {}
    if self.use_local_feature:
        # Camera matrices are needed to project query points into image
        # space when sampling local features.
        camera_args = get_camera_args(data, 'points.loc', 'points.scale',
                                      device=device)
        Rt = camera_args['Rt']
        K = camera_args['K']
        f3, f2, f1 = self.model.encode_inputs(inputs, p, Rt, K)
    else:
        f3, f2, f1 = self.model.encode_inputs(inputs)
    q_z = self.model.infer_z(p, occ, f3, **kwargs)
    z = q_z.rsample()

    # KL-divergence
    kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
    loss = kl.mean()

    # General points
    p_r = self.model.decode(p, z, f3, f2, f1, **kwargs)
    logits = p_r.logits
    probs = p_r.probs
    if self.loss_type == 'cross_entropy':
        loss_i = F.binary_cross_entropy_with_logits(
            logits, occ, reduction='none')
    elif self.loss_type == 'l2':
        # FIX: F.sigmoid is deprecated; torch.sigmoid is the supported API.
        pred = torch.sigmoid(logits)
        loss_i = torch.pow((pred - occ), 2)
    elif self.loss_type == 'l1':
        pred = torch.sigmoid(logits)
        loss_i = torch.abs(pred - occ)
    else:
        # Fallback: plain BCE on probabilities.
        pred = torch.sigmoid(logits)
        loss_i = F.binary_cross_entropy(pred, occ, reduction='none')

    # Clamp tiny per-point losses below the tolerance threshold.
    if self.loss_tolerance_episolon != 0.:
        loss_i = torch.clamp(loss_i, min=self.loss_tolerance_episolon,
                             max=100)
    # Reweight points by whether the prediction already lies on the correct
    # side of the decision boundary.
    if self.sign_lambda != 0.:
        w = 1. - self.sign_lambda * torch.sign(occ - 0.5) \
            * torch.sign(probs - self.threshold)
        loss_i = loss_i * w
    # Emphasize near-surface points (labels strictly between 0 and 1).
    if self.surface_loss_weight != 1.:
        w = ((occ > 0.) & (occ < 1.)).float()
        w = w * (self.surface_loss_weight - 1) + 1
        loss_i = loss_i * w

    loss = loss + loss_i.sum(-1).mean()
    return loss
def train_step(self, data):
    r''' Performs a training step of the model.

    Arguments:
        data (tensor): The input data
    '''
    self.model.train()

    device = self.device
    gt_points = data.get('pointcloud').to(device)
    gt_normals = data.get('pointcloud.normals').to(device)
    images = data.get('inputs').to(device)

    cam = common.get_camera_args(data, 'pointcloud.loc', 'pointcloud.scale',
                                 device=device)
    world_mat, camera_mat = cam['Rt'], cam['K']

    # Bring GT points into the camera coordinate system.
    points_cam = common.transform_points(gt_points, world_mat)
    # Normals transform with the rotation part of Rt only (no translation).
    rotation = world_mat[:, :, :3]
    normals_cam = common.transform_points(gt_normals, rotation)

    outputs1, outputs2 = self.model(images, camera_mat)
    loss = self.compute_loss(outputs1, outputs2, points_cam, normals_cam,
                             images)

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.item()
def visualize(self, data):
    r''' Visualises the GT point cloud and predicted vertices (as a point cloud).

    Arguments:
        data (tensor): input data
    '''
    points_gt = data.get('pointcloud').to(self.device)
    img = data.get('inputs').to(self.device)
    camera_args = common.get_camera_args(data, 'pointcloud.loc',
                                         'pointcloud.scale',
                                         device=self.device)
    world_mat, camera_mat = camera_args['Rt'], camera_args['K']

    # FIX: the previous isdir()+mkdir() pair was racy and failed when a
    # parent directory was missing; makedirs with exist_ok handles both.
    os.makedirs(self.vis_dir, exist_ok=True)

    with torch.no_grad():
        outputs1, outputs2 = self.model(img, camera_mat)
    pred_vertices_1, pred_vertices_2, pred_vertices_3 = outputs1
    # Predictions are in camera coordinates; map back to world space.
    points_out = common.transform_points_back(pred_vertices_3, world_mat)
    points_out = points_out.cpu().numpy()

    input_img_path = os.path.join(self.vis_dir, 'input.png')
    save_image(img.cpu(), input_img_path, nrow=4)

    points_gt = points_gt.cpu().numpy()
    batch_size = img.size(0)
    for i in range(batch_size):
        out_file = os.path.join(self.vis_dir, '%03d.png' % i)
        out_file_gt = os.path.join(self.vis_dir, '%03d_gt.png' % i)
        vis.visualize_pointcloud(points_out[i], out_file=out_file)
        vis.visualize_pointcloud(points_gt[i], out_file=out_file_gt)
def compute_loss(self, data, noise_level=0.1):
    ''' Computes the loss.

    Args:
        data (dict): data dictionary
        noise_level (float): amplitude of the color noise added to the
            input images (was the hard-coded local ``npp = 0.1``)
    '''
    device = self.device
    p = data.get('points').to(device)
    occ = data.get('points.occ').to(device)

    inputs = data.get('inputs', torch.empty(p.size(0), 0)).to(device)
    inputs = self.colornoise(inputs, noise_level)

    # Camera matrices come from the per-sample loc/scale. The previous
    # loads of 'inputs.world_mat'/'inputs.camera_mat' were dead code:
    # both were unconditionally overwritten right here.
    camera_args = common.get_camera_args(data, 'points.loc', 'points.scale',
                                         device=self.device)
    world_mat, camera_mat = camera_args['Rt'], camera_args['K']

    # NOTE(review): looks like a debug/inspection hook; kept because it is
    # a runtime side effect — confirm whether it is still wanted.
    self.vis(False, inputs[0].cpu().numpy())

    kwargs = {}
    G, c = self.model.encode_inputs(inputs)
    # Project query points into the multi-scale feature maps
    # (last argument False: non-perspective variant).
    v = self.model.gproj(p, G, c, world_mat, camera_mat, inputs, False)

    q_z = self.model.infer_z(p, occ, c, **kwargs)
    z = q_z.rsample()

    # KL-divergence
    kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
    loss = kl.mean()

    # General points
    logits = self.model.decode(p, z, v, **kwargs).logits
    loss_i = F.binary_cross_entropy_with_logits(logits, occ,
                                                reduction='none')
    loss = loss + loss_i.sum(-1).mean()

    return loss
def compute_loss(self, data):
    ''' Computes the loss.

    Args:
        data (dict): data dictionary
    '''
    device = self.device
    p = data.get('points').to(device)
    occ = data.get('points.occ').to(device)
    inputs = data.get('inputs', torch.empty(p.size(0), 0)).to(device)

    # Camera matrices come from the per-sample loc/scale. The previous
    # loads of 'inputs.world_mat'/'inputs.camera_mat' (and the unused
    # 'inputs.image' tensor) were dead code and have been removed.
    camera_args = common.get_camera_args(data, 'points.loc', 'points.scale',
                                         device=self.device)
    world_mat, camera_mat = camera_args['Rt'], camera_args['K']

    kwargs = {}
    c = self.model.encode_inputs(inputs)
    # Project query points into the feature maps
    # (last argument False: non-perspective variant).
    v = self.model.gproj(p, c, world_mat, camera_mat, inputs, False)

    q_z = self.model.infer_z(p, occ, c, **kwargs)
    z = q_z.rsample()

    # KL-divergence
    kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
    loss = kl.mean()

    # General points
    logits = self.model.decode(p, z, v, **kwargs).logits
    loss_i = F.binary_cross_entropy_with_logits(logits, occ,
                                                reduction='none')
    loss = loss + loss_i.sum(-1).mean()

    return loss
def compute_loss(self, data):
    ''' Computes the loss.

    Args:
        data (dict): data dictionary
    '''
    device = self.device
    points = data.get('points').to(device)
    # Optionally binarize soft occupancy labels at 0.5.
    if self.binary_occ:
        occ = (data.get('points.occ') >= 0.5).float().to(device)
    else:
        occ = data.get('points.occ').to(device)
    inputs = data.get('inputs', torch.empty(points.size(0), 0)).to(device)

    kwargs = {}
    if self.use_local_feature:
        # Camera matrices are required for local feature sampling.
        cam = get_camera_args(data, 'points.loc', 'points.scale',
                              device=device)
        f3, f2, f1 = self.model.encode_inputs(inputs, points,
                                              cam['Rt'], cam['K'])
    else:
        f3, f2, f1 = self.model.encode_inputs(inputs)

    q_z = self.model.infer_z(points, occ, f3, **kwargs)
    z = q_z.rsample()

    # KL term of the ELBO.
    kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
    loss = kl.mean()

    # Reconstruction term on the general query points.
    p_r = self.model.decode(points, z, f3, f2, f1, **kwargs)

    # Base occupancy loss, then the configured loss strategies.
    loss_i = get_occ_loss(p_r.logits, occ, self.loss_type)
    loss_i = occ_loss_postprocess(
        loss_i, occ, p_r.probs, self.loss_tolerance_episolon,
        self.sign_lambda, self.threshold, self.surface_loss_weight)

    return loss + loss_i.sum(-1).mean()
def generate_mesh(self, data, return_stats=True):
    ''' Generates the output mesh.

    Args:
        data (tensor): data tensor
        return_stats (bool): whether stats should be returned
    '''
    self.model.eval()
    device = self.device
    stats_dict = {}

    # Camera matrices come from the per-sample loc/scale. The previous
    # loads of 'inputs.world_mat'/'inputs.camera_mat' were dead code:
    # both were unconditionally overwritten right here.
    camera_args = common.get_camera_args(data, 'points.loc', 'points.scale',
                                         device=self.device)
    world_mat, camera_mat = camera_args['Rt'], camera_args['K']

    inputs = data.get('inputs', torch.empty(1, 0)).to(device)
    kwargs = {}

    # Preprocess if requires
    if self.preprocessor is not None:
        t0 = time.time()
        with torch.no_grad():
            inputs = self.preprocessor(inputs)
        stats_dict['time (preprocess)'] = time.time() - t0

    # Encode inputs
    t0 = time.time()
    with torch.no_grad():
        G, c = self.model.encode_inputs(inputs)
    stats_dict['time (encode inputs)'] = time.time() - t0

    # Sample (or take the mean of) the latent code and extract the mesh.
    z = self.model.get_z_from_prior((1, ), sample=self.sample).to(device)
    mesh = self.generate_from_latent(z, G, c, stats_dict=stats_dict,
                                     world_mat=world_mat,
                                     camera_mat=camera_mat, **kwargs)

    if return_stats:
        return mesh, stats_dict
    else:
        return mesh
def eval_step(self, data):
    r''' Performs an evaluation step.

    Arguments:
        data (tensor): input data
    '''
    self.model.eval()

    device = self.device
    gt_points = data.get('pointcloud').to(device)
    images = data.get('inputs').to(device)
    gt_normals = data.get('pointcloud.normals').to(device)

    # Move GT points and normals into camera coordinates; normals use the
    # rotation part of Rt only.
    cam = common.get_camera_args(data, 'pointcloud.loc', 'pointcloud.scale',
                                 device=device)
    world_mat, camera_mat = cam['Rt'], cam['K']
    points_cam = common.transform_points(gt_points, world_mat)
    normals_cam = common.transform_points(gt_normals, world_mat[:, :, :3])

    with torch.no_grad():
        outputs1, outputs2 = self.model(images, camera_mat)
    pred_v1, pred_v2, pred_v3 = outputs1

    loss = self.compute_loss(outputs1, outputs2, points_cam, normals_cam,
                             images)

    # Individual loss terms, evaluated on the finest (block 3) prediction.
    d_fwd, d_bwd, idx_fwd, idx_bwd = chamfer_distance(pred_v3, points_cam,
                                                      give_id=True)
    chamfer = (d_fwd + d_bwd).mean()
    edge = self.edge_length_loss(pred_v3, 3)
    normal = self.normal_loss(pred_v3, normals_cam, idx_fwd, 3)
    laplace, move = self.laplacian_loss(pred_v3, outputs2[2], block_id=3)

    return {
        'loss': loss.item(),
        'chamfer': chamfer.item(),
        'edge': edge.item(),
        'normal': normal.item(),
        'laplace': laplace.item(),
        'move': move.item(),
    }
def visualize(self, data):
    ''' Performs a visualization step for the data.

    Args:
        data (dict): data dictionary
    '''
    self.model.eval()

    device = self.device
    batch_size = data['points'].size(0)
    inputs = data.get('inputs', torch.empty(batch_size, 0)).to(device)
    if self.use_local_feature:
        # Camera matrices are only needed for local feature sampling.
        cam = get_camera_args(data, 'points.loc', 'points.scale',
                              device=device)
        Rt, K = cam['Rt'], cam['K']

    # Regular 32^3 query grid spanning [-0.5, 0.5]^3, repeated per sample.
    shape = (32, 32, 32)
    grid = make_3d_grid([-0.5] * 3, [0.5] * 3, shape).to(device)
    grid = grid.expand(batch_size, *grid.size())

    kwargs = {}
    with torch.no_grad():
        if self.use_local_feature:
            p_r = self.model(grid, inputs, Rt, K, sample=self.eval_sample,
                             **kwargs)
        else:
            p_r = self.model(grid, inputs, sample=self.eval_sample, **kwargs)

    occ_hat = p_r.probs.view(batch_size, *shape)
    voxels_out = (occ_hat >= self.threshold).cpu().numpy()

    for i in trange(batch_size):
        input_img_path = os.path.join(self.vis_dir, '%03d_in.png' % i)
        vis.visualize_data(inputs[i].cpu(), self.input_type, input_img_path)
        vis.visualize_voxels(voxels_out[i],
                             os.path.join(self.vis_dir, '%03d.png' % i))
def generate_pointcloud(self, data):
    ''' Generates a pointcloud by only returning the vertices

    Arguments:
        data (tensor): input data
    '''
    images = data.get('inputs').to(self.device)
    cam = common.get_camera_args(
        data, 'pointcloud.loc', 'pointcloud.scale', device=self.device)
    world_mat, camera_mat = cam['Rt'], cam['K']

    with torch.no_grad():
        outputs1, _ = self.model(images, camera_mat)
    # Only the final refinement stage's vertices are kept.
    _, _, final_vertices = outputs1

    # Map predictions from camera back to world coordinates.
    restored = common.transform_points_back(final_vertices, world_mat)
    return restored.squeeze().cpu().numpy()
def eval_step(self, data):
    ''' Performs an evaluation step.

    Args:
        data (dict): data dictionary
    '''
    self.model.eval()
    device = self.device
    threshold = self.threshold
    eval_dict = {}

    # Compute elbo
    points = data.get('points').to(device)
    occ = data.get('points.occ').to(device)
    inputs = data.get('inputs', torch.empty(points.size(0), 0)).to(device)
    voxels_occ = data.get('voxels')
    points_iou = data.get('points_iou').to(device)
    occ_iou = data.get('points_iou.occ').to(device)

    # Camera matrices come from the per-sample loc/scale. The previous
    # loads of 'inputs.world_mat'/'inputs.camera_mat' were dead code:
    # both were unconditionally overwritten right here.
    camera_args = common.get_camera_args(data, 'points.loc', 'points.scale',
                                         device=self.device)
    world_mat, camera_mat = camera_args['Rt'], camera_args['K']

    kwargs = {}
    with torch.no_grad():
        elbo, rec_error, kl = self.model.compute_elbo(
            points, occ, inputs, world_mat, camera_mat, **kwargs)
    eval_dict['loss'] = -elbo.mean().item()
    eval_dict['rec_error'] = rec_error.mean().item()
    eval_dict['kl'] = kl.mean().item()

    # Compute iou
    batch_size = points.size(0)
    with torch.no_grad():
        p_out = self.model(points_iou, inputs, world_mat, camera_mat,
                           sample=self.eval_sample, **kwargs)
    occ_iou_np = (occ_iou >= 0.5).cpu().numpy()
    occ_iou_hat_np = (p_out.probs >= threshold).cpu().numpy()
    eval_dict['iou'] = compute_iou(occ_iou_np, occ_iou_hat_np).mean()

    # Estimate voxel iou
    if voxels_occ is not None:
        voxels_occ = voxels_occ.to(device)
        # Cell-centered 32^3 grid inside the unit cube.
        points_voxels = make_3d_grid((-0.5 + 1 / 64, ) * 3,
                                     (0.5 - 1 / 64, ) * 3, (32, ) * 3)
        points_voxels = points_voxels.expand(batch_size,
                                             *points_voxels.size())
        points_voxels = points_voxels.to(device)
        with torch.no_grad():
            p_out = self.model(points_voxels, inputs, world_mat, camera_mat,
                               sample=self.eval_sample, **kwargs)
        voxels_occ_np = (voxels_occ >= 0.5).cpu().numpy()
        occ_hat_np = (p_out.probs >= threshold).cpu().numpy()
        eval_dict['iou_voxels'] = compute_iou(voxels_occ_np,
                                              occ_hat_np).mean()

    return eval_dict
def compose_inputs(data, mode='train', device=None, input_type='depth_pred',
                   use_gt_depth_map=False, depth_map_mix=False,
                   with_img=False, depth_pointcloud_transfer=None,
                   local=False):
    ''' Assembles encoder inputs (and a raw-data dict) from a data batch.

    Args:
        data (dict): data dictionary
        mode (str): one of 'train', 'val', 'test'
        device: torch device the tensors are moved to
        input_type (str): 'depth_pred' or 'depth_pointcloud'
        use_gt_depth_map (bool): use GT depth maps instead of predicted ones
        depth_map_mix (bool): randomly blend predicted and GT depth maps
            (training mode only)
        with_img (bool): also pack the RGB image alongside the depth map
        depth_pointcloud_transfer (str): coordinate transfer for the depth
            point cloud ('world', 'world_scale_model', 'view',
            'view_scale_model') or None
        local (bool): wrap inputs in a dict (with camera matrices for the
            'depth_pred' branch) for local-feature models

    Returns:
        (encoder_inputs, raw_data) tuple; encoder_inputs is a tensor or a
        dict keyed by None/'world_mat'/'camera_mat' depending on the flags.
    '''
    assert mode in ('train', 'val', 'test')
    raw_data = {}
    if input_type == 'depth_pred':
        gt_mask = data.get('inputs.mask').to(device).byte()
        raw_data['mask'] = gt_mask
        batch_size = gt_mask.size(0)
        if use_gt_depth_map:
            gt_depth_maps = data.get('inputs.depth').to(device)
            # Zero-out (background) pixels outside the object mask.
            background_setting(gt_depth_maps, gt_mask)
            encoder_inputs = gt_depth_maps
            raw_data['depth'] = gt_depth_maps
        else:
            pr_depth_maps = data.get('inputs.depth_pred').to(device)
            background_setting(pr_depth_maps, gt_mask)
            raw_data['depth_pred'] = pr_depth_maps
            if depth_map_mix and mode == 'train':
                # Blend predicted and GT depth with a random per-sample
                # alpha (a form of data augmentation; training only).
                gt_depth_maps = data.get('inputs.depth').to(device)
                background_setting(gt_depth_maps, gt_mask)
                raw_data['depth'] = gt_depth_maps
                alpha = torch.rand(batch_size, 1, 1, 1).to(device)
                pr_depth_maps = pr_depth_maps * alpha \
                    + gt_depth_maps * (1.0 - alpha)
            encoder_inputs = pr_depth_maps
        if with_img:
            # Pack the RGB image next to the depth map; note the raw image
            # is stored under the None key in raw_data.
            img = data.get('inputs').to(device)
            raw_data[None] = img
            encoder_inputs = {'img': img, 'depth': encoder_inputs}
        if local:
            # Local-feature models additionally need the camera matrices;
            # the actual inputs move under the None key.
            camera_args = get_camera_args(data, 'points.loc', 'points.scale',
                                          device=device)
            Rt = camera_args['Rt']
            K = camera_args['K']
            encoder_inputs = {
                None: encoder_inputs,
                'world_mat': Rt,
                'camera_mat': K,
            }
            raw_data['world_mat'] = Rt
            raw_data['camera_mat'] = K
        return encoder_inputs, raw_data
    elif input_type == 'depth_pointcloud':
        encoder_inputs = data.get('inputs.depth_pointcloud').to(device)
        if depth_pointcloud_transfer is not None:
            if depth_pointcloud_transfer in ('world', 'world_scale_model'):
                # Swap the first two coordinate axes (y, x, z) -> (x, y, z);
                # assumes the stored cloud uses swapped axes — TODO confirm.
                encoder_inputs = encoder_inputs[:, :, [1, 0, 2]]
                world_mat = get_world_mat(data, transpose=None, device=device)
                raw_data['world_mat'] = world_mat
                R = world_mat[:, :, :3]
                # R's inverse is R^T
                encoder_inputs = transform_points(encoder_inputs,
                                                  R.transpose(1, 2))
                # or encoder_inputs = transform_points_back(encoder_inputs, R)
                if depth_pointcloud_transfer == 'world_scale_model':
                    # Scale by the z-component of the camera translation.
                    t = world_mat[:, :, 3:]
                    encoder_inputs = encoder_inputs * t[:, 2:, :]
            elif depth_pointcloud_transfer in ('view', 'view_scale_model'):
                encoder_inputs = encoder_inputs[:, :, [1, 0, 2]]
                if depth_pointcloud_transfer == 'view_scale_model':
                    world_mat = get_world_mat(data, transpose=None,
                                              device=device)
                    raw_data['world_mat'] = world_mat
                    t = world_mat[:, :, 3:]
                    encoder_inputs = encoder_inputs * t[:, 2:, :]
            else:
                raise NotImplementedError
        raw_data['depth_pointcloud'] = encoder_inputs
        if local:
            #assert depth_pointcloud_transfer.startswith('world')
            encoder_inputs = {None: encoder_inputs}
        return encoder_inputs, raw_data
    else:
        raise NotImplementedError