def visualise_projection(self, points, world_mat, camera_mat, img,
                         output_file='out.png'):
    r''' Visualises the transformation and projection to the image plane.

    The first points of the batch are transformed and projected onto the
    respective image, and the visualisation is saved to the provided
    output_file path.

    Args:
        points (tensor): batch of point cloud points
        world_mat (tensor): batch of matrices that rotate the point cloud
            to camera-based coordinates
        camera_mat (tensor): batch of camera matrices that project to the
            2D image plane
        img (tensor): batch of ground-truth images
        output_file (string): path where the output should be saved
    '''
    points_transformed = common.transform_points(points, world_mat)
    points_img = common.project_to_camera(points_transformed, camera_mat)
    pimg2 = points_img[0].detach().cpu().numpy()
    image = img[0].cpu().numpy()
    plt.imshow(image.transpose(1, 2, 0))
    # Map normalised [-1, 1] coordinates to pixels: x scales with the
    # image width (shape[2]), y with the height (shape[1]); the original
    # code had these swapped, which only worked for square images.
    plt.plot((pimg2[:, 0] + 1) * image.shape[2] / 2,
             (pimg2[:, 1] + 1) * image.shape[1] / 2, 'x')
    plt.savefig(output_file)
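
# The two `common` helpers used above are defined elsewhere. A minimal
# sketch of their assumed behaviour (batched 3x4 [R|t] matrices and a
# pinhole projection with perspective divide to normalised [-1, 1] image
# coordinates); names are prefixed so they do not shadow the real ones:
import torch


def _sketch_transform_points(points, transform):
    ''' Applies a batch of 3x4 [R|t] (or plain 3x3) matrices to
    B x N x 3 points. Assumed behaviour of common.transform_points. '''
    if transform.size(2) == 4:
        R, t = transform[:, :, :3], transform[:, :, 3:]
        return points @ R.transpose(1, 2) + t.transpose(1, 2)
    return points @ transform.transpose(1, 2)


def _sketch_project_to_camera(points, camera_mat):
    ''' Perspective divide to normalised [-1, 1] image coordinates.
    Assumed behaviour of common.project_to_camera. '''
    p = _sketch_transform_points(points, camera_mat)
    return p[..., :2] / p[..., 2:]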
def forward_local_second_step(self, data, c, local_feat_maps, pts):
    world_mat = data['world_mat']
    camera_mat = data['camera_mat']
    assert self.local

    pts = common.transform_points(pts, world_mat)
    points_img = common.project_to_camera(pts, camera_mat)
    points_img = points_img.unsqueeze(1)

    # Get local features by sampling each feature map at the projected
    # point locations.
    local_feats = []
    for f in local_feat_maps:
        # f = f.detach()
        f = F.grid_sample(f, points_img, mode='nearest')
        f = f.squeeze(2)
        local_feats.append(f)

    local_feats = torch.cat(local_feats, dim=1)
    local_feats = local_feats.transpose(1, 2)  # batch * n_pts * f_dim
    local_feats = self.local_fc(local_feats)

    # c: B * c_dim
    # local_feats: B * n_pts * c_dim
    return c, local_feats
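
# Note on the sampling trick used above: F.grid_sample samples an
# N x C x H x W input at normalised [-1, 1] locations, and with a grid of
# shape N x 1 x n_pts x 2 it returns N x C x 1 x n_pts. That is why the
# points are unsqueezed into a fake height dimension and dim 2 is squeezed
# away afterwards. A self-contained demonstration with dummy tensors:
import torch
import torch.nn.functional as F


def _demo_grid_sample():
    fmap = torch.randn(2, 64, 56, 56)        # N x C x H x W feature map
    pts2d = torch.rand(2, 1024, 2) * 2 - 1   # N x n_pts x 2 in [-1, 1]
    grid = pts2d.unsqueeze(1)                # N x 1 x n_pts x 2
    feat = F.grid_sample(fmap, grid, mode='nearest', align_corners=False)
    assert feat.shape == (2, 64, 1, 1024)    # N x C x 1 x n_pts
    return feat.squeeze(2)                   # N x C x n_pts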
def encode_second_step(self, f3, f2, f1, pts, world_mat, camera_mat):
    pts = common.transform_points(pts, world_mat)
    points_img = common.project_to_camera(pts, camera_mat)
    points_img = points_img.unsqueeze(1)

    f2 = f2.detach()
    f2 = F.relu(f2)
    f2 = F.grid_sample(f2, points_img)
    f2 = f2.squeeze(2)
    f2 = self.f2_conv(f2)

    f1 = f1.detach()
    f1 = F.relu(f1)
    f1 = F.grid_sample(f1, points_img)
    f1 = f1.squeeze(2)
    f1 = self.f1_conv(f1)

    f3 = self.fc3(f3)

    if self.batch_norm:
        f3 = self.f3_bn(f3)
        f2 = self.f2_bn(f2)
        f1 = self.f1_bn(f1)

    f2 = f2.transpose(1, 2)
    f1 = f1.transpose(1, 2)
    # f2: batch * n_pts * fmap_dim
    # f1: batch * n_pts * fmap_dim
    return f3, f2, f1
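
# The projection heads used above (fc3, f2_conv, f1_conv, *_bn) are
# defined elsewhere. A hypothetical skeleton consistent with the shapes
# in encode_second_step; the channel widths (512/256/128 in, f_dim out)
# are assumptions, not taken from the original module:
import torch.nn as nn


class _SketchSecondStepHeads(nn.Module):
    def __init__(self, f_dim=128, batch_norm=True):
        super().__init__()
        self.batch_norm = batch_norm
        self.fc3 = nn.Linear(512, f_dim)         # global feature f3
        self.f2_conv = nn.Conv1d(256, f_dim, 1)  # per-point 1x1 conv on f2
        self.f1_conv = nn.Conv1d(128, f_dim, 1)  # per-point 1x1 conv on f1
        if batch_norm:
            self.f3_bn = nn.BatchNorm1d(f_dim)
            self.f2_bn = nn.BatchNorm1d(f_dim)
            self.f1_bn = nn.BatchNorm1d(f_dim)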
def forward(self, x, fm, camera_mat, img=None, visualise=False):
    ''' Performs a forward pass through the GP layer.

    Args:
        x (tensor): coordinates of shape (batch_size, num_vertices, 3)
        fm (list): list of feature maps from which the image features
            should be pooled
        camera_mat (tensor): camera matrices for the transformation to
            the 2D image plane
        img (tensor): images (for visualisation purposes only)
        visualise (bool): whether to save a visualisation of the projection
    '''
    points_img = common.project_to_camera(x, camera_mat)
    points_img = points_img.unsqueeze(1)

    feats = [x]
    for fmap in fm:
        # Bilinearly interpolate to get the corresponding features
        feat_pts = F.grid_sample(fmap, points_img)
        feat_pts = feat_pts.squeeze(2)
        feats.append(feat_pts.transpose(1, 2))

    # For visualisation purposes only
    if visualise and (img is not None):
        self.visualise_projection(
            points_img.squeeze(1)[0].detach().cpu().numpy(),
            img[0].cpu().numpy())

    outputs = torch.cat(feats, dim=2)
    return outputs
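
# A self-contained shape walk-through of the pooling step above, with
# dummy tensors standing in for the projected coordinates and the feature
# maps (all sizes here are illustrative assumptions):
import torch
import torch.nn.functional as F


def _demo_gp_pooling():
    batch, n_verts = 2, 156
    grid = torch.rand(batch, 1, n_verts, 2) * 2 - 1     # projected coords
    fm = [torch.randn(batch, 64, 56, 56),
          torch.randn(batch, 128, 28, 28)]
    feats = [torch.rand(batch, n_verts, 3)]             # 3D coordinates
    for fmap in fm:
        f = F.grid_sample(fmap, grid, align_corners=False)  # B x C x 1 x n
        feats.append(f.squeeze(2).transpose(1, 2))          # B x n x C
    out = torch.cat(feats, dim=2)
    assert out.shape == (batch, n_verts, 3 + 64 + 128)
    return out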
def forward_local(self, data, pts):
    assert self.local
    world_mat = data['world_mat']
    camera_mat = data['camera_mat']
    x = data[None]  # the input image is stored under the None key

    pts = transform_points(pts, world_mat)
    points_img = project_to_camera(pts, camera_mat)
    points_img = points_img.unsqueeze(1)

    local_feat_maps = []
    if self.normalize:
        x = normalize_imagenet(x)
    x = self.features.conv1(x)
    x = self.features.bn1(x)
    x = self.features.relu(x)
    x = self.features.maxpool(x)  # 64 * 112 * 112
    x = self.features.layer1(x)
    local_feat_maps.append(x)     # 64 * 56 * 56
    x = self.features.layer2(x)
    local_feat_maps.append(x)     # 128 * 28 * 28
    x = self.features.layer3(x)
    local_feat_maps.append(x)     # 256 * 14 * 14
    x = self.features.layer4(x)
    local_feat_maps.append(x)     # 512 * 7 * 7
    x = self.features.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)

    # Get local features by sampling each feature map at the projected
    # point locations.
    local_feats = []
    for f in local_feat_maps:
        # f = f.detach()
        f = F.grid_sample(f, points_img, mode='nearest')
        f = f.squeeze(2)
        local_feats.append(f)

    local_feats = torch.cat(local_feats, dim=1)
    local_feats = local_feats.transpose(1, 2)  # batch * n_pts * f_dim
    local_feats = self.local_fc(local_feats)

    # x: B * c_dim
    # local_feats: B * n_pts * c_dim
    return x, local_feats
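
# normalize_imagenet is imported from elsewhere; a sketch of its assumed
# behaviour, using the standard ImageNet channel statistics (renamed so it
# does not shadow the real helper):
def _sketch_normalize_imagenet(x):
    ''' Normalises RGB images with the usual ImageNet mean and std. '''
    x = x.clone()
    x[:, 0] = (x[:, 0] - 0.485) / 0.229
    x[:, 1] = (x[:, 1] - 0.456) / 0.224
    x[:, 2] = (x[:, 2] - 0.406) / 0.225
    return x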
def compute_loss(self, data):
    ''' Computes the loss.

    Args:
        data (dict): data dictionary
    '''
    device = self.device
    encoder_inputs, raw_data = compose_inputs(
        data, mode='train', device=self.device,
        input_type=self.input_type,
        depth_pointcloud_transfer=self.depth_pointcloud_transfer,
    )

    world_mat = None
    if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
        if 'world_mat' in raw_data:
            world_mat = raw_data['world_mat']
        else:
            world_mat = get_world_mat(data, device=device)
    gt_pc = compose_pointcloud(data, device, self.gt_pointcloud_transfer,
                               world_mat=world_mat)

    if self.model.encoder_world_mat is not None:
        out = self.model(encoder_inputs, world_mat=world_mat)
    else:
        out = self.model(encoder_inputs)

    loss = 0
    if isinstance(out, tuple):
        out, trans_feat = out
        if isinstance(self.model.encoder,
                      (PointNetEncoder, PointNetResEncoder)):
            loss = loss + 0.001 * feature_transform_reguliarzer(trans_feat)

    # Reconstruction loss: chamfer distance or earth mover's distance
    if self.loss_type == 'cd':
        dist1, dist2 = cd.chamfer_distance(out, gt_pc)
        # accumulate (loss + ...) so the regulariser term above is kept
        loss = loss + (dist1.mean(1) + dist2.mean(1)).mean() / 2.
    else:
        out_pts_count = out.size(1)
        loss = loss + (emd.earth_mover_distance(
            out, gt_pc, transpose=False) / out_pts_count).mean()

    # View penalty loss
    if self.view_penalty:
        gt_mask_flow = data.get('inputs.mask_flow').to(
            device)  # B * 1 * H * W
        if world_mat is None:
            world_mat = get_world_mat(data, device=device)
        camera_mat = get_camera_mat(data, device=device)

        # Project using the world & camera matrices
        if self.gt_pointcloud_transfer == 'world_scale_model':
            out_pts = transform_points(out, world_mat)
        elif self.gt_pointcloud_transfer == 'view_scale_model':
            # t is B * 3 * 1; transpose so it broadcasts over the points
            # (also fixes a use-before-assignment of out_pts)
            t = world_mat[:, :, 3:]
            out_pts = out + t.transpose(1, 2)
        elif self.gt_pointcloud_transfer == 'view':
            t = world_mat[:, :, 3:]
            out_pts = out * t[:, 2:, :]
            out_pts = out_pts + t.transpose(1, 2)
        else:
            raise NotImplementedError

        out_pts_img = project_to_camera(out_pts, camera_mat)
        out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

        out_mask_flow = F.grid_sample(gt_mask_flow,
                                      out_pts_img)  # B * 1 * 1 * n_pts
        loss_mask_flow = F.relu(1. - self.mask_flow_eps - out_mask_flow,
                                inplace=True).mean()
        loss = loss + self.loss_mask_flow_ratio * loss_mask_flow

        if self.view_penalty == 'mask_flow_and_depth':
            # Depth test loss
            t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
            gt_mask = data.get('inputs.mask').byte().to(device)
            depth_pred = data.get('inputs.depth_pred').to(
                device) * t_scale  # absolute depth from view
            background_setting(depth_pred, gt_mask)
            depth_z = out_pts[:, :, 2:].transpose(1, 2)
            corresponding_z = F.grid_sample(
                depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
            corresponding_z = corresponding_z.squeeze(1)

            loss_depth_test = F.relu(
                depth_z - self.depth_test_eps - corresponding_z,
                inplace=True).mean()
            loss = loss + self.loss_depth_test_ratio * loss_depth_test

    return loss
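
# Toy illustration of the mask-flow penalty above: gt_mask_flow is ~1
# inside the object silhouette, so relu(1 - eps - mask(p)) is zero for
# points projecting inside the mask and positive outside it. All values
# below are made up for the demonstration:
import torch
import torch.nn.functional as F


def _demo_mask_flow_penalty(mask_flow_eps=0.1):
    gt_mask_flow = torch.zeros(1, 1, 64, 64)
    gt_mask_flow[:, :, 16:48, 16:48] = 1.0       # square "silhouette"
    pts_img = torch.tensor([[[[0.0, 0.0],        # centre: inside the mask
                              [0.9, 0.9]]]])     # corner: outside the mask
    sampled = F.grid_sample(gt_mask_flow, pts_img, align_corners=False)
    return F.relu(1. - mask_flow_eps - sampled)  # ~[0.0, 0.9]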
def eval_step(self, data):
    ''' Performs an evaluation step.

    Args:
        data (dict): data dictionary
    '''
    self.model.eval()
    device = self.device
    encoder_inputs, raw_data = compose_inputs(
        data, mode='train', device=self.device,
        input_type=self.input_type,
        depth_pointcloud_transfer=self.depth_pointcloud_transfer)

    world_mat = None
    if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
        if 'world_mat' in raw_data:
            world_mat = raw_data['world_mat']
        else:
            world_mat = get_world_mat(data, device=device)
    gt_pc = compose_pointcloud(data, device, self.gt_pointcloud_transfer,
                               world_mat=world_mat)
    batch_size = gt_pc.size(0)

    with torch.no_grad():
        if self.model.encoder_world_mat is not None:
            out = self.model(encoder_inputs, world_mat=world_mat)
        else:
            out = self.model(encoder_inputs)

        if isinstance(out, tuple):
            out, trans_feat = out

        eval_dict = {}
        if batch_size == 1:
            pointcloud_hat = out.cpu().squeeze(0).numpy()
            pointcloud_gt = gt_pc.cpu().squeeze(0).numpy()
            eval_dict = self.mesh_evaluator.eval_pointcloud(
                pointcloud_hat, pointcloud_gt)

        # Reconstruction loss: chamfer distance or earth mover's distance
        if self.loss_type == 'cd':
            dist1, dist2 = cd.chamfer_distance(out, gt_pc)
            loss = (dist1.mean(1) + dist2.mean(1)) / 2.
        else:
            loss = emd.earth_mover_distance(out, gt_pc, transpose=False)

        if self.gt_pointcloud_transfer in ('world_scale_model',
                                           'view_scale_model', 'view'):
            pointcloud_scale = data.get('pointcloud.scale').to(
                device).view(batch_size, 1, 1)
            loss = loss / (pointcloud_scale ** 2)

        if self.gt_pointcloud_transfer == 'view':
            if world_mat is None:
                world_mat = get_world_mat(data, device=device)
            t_scale = world_mat[:, 2:, 3:]
            loss = loss * (t_scale ** 2)

        if self.loss_type == 'cd':
            loss = loss.mean()
            eval_dict['chamfer'] = loss.item()
        else:
            out_pts_count = out.size(1)
            loss = (loss / out_pts_count).mean()
            eval_dict['emd'] = loss.item()

        # View penalty loss
        if self.view_penalty:
            gt_mask_flow = data.get('inputs.mask_flow').to(
                device)  # B * 1 * H * W
            if world_mat is None:
                world_mat = get_world_mat(data, device=device)
            camera_mat = get_camera_mat(data, device=device)

            # Project using the world & camera matrices
            if self.gt_pointcloud_transfer == 'world_scale_model':
                out_pts = transform_points(out, world_mat)
            elif self.gt_pointcloud_transfer == 'view_scale_model':
                # t is B * 3 * 1; transpose so it broadcasts over the
                # points (also fixes a use-before-assignment of out_pts)
                t = world_mat[:, :, 3:]
                out_pts = out + t.transpose(1, 2)
            elif self.gt_pointcloud_transfer == 'view':
                t = world_mat[:, :, 3:]
                out_pts = out * t[:, 2:, :]
                out_pts = out_pts + t.transpose(1, 2)
            else:
                raise NotImplementedError

            out_pts_img = project_to_camera(out_pts, camera_mat)
            out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

            out_mask_flow = F.grid_sample(gt_mask_flow,
                                          out_pts_img)  # B * 1 * 1 * n_pts
            loss_mask_flow = F.relu(1. - self.mask_flow_eps - out_mask_flow,
                                    inplace=True).mean()
            loss = self.loss_mask_flow_ratio * loss_mask_flow
            eval_dict['loss_mask_flow'] = loss.item()

            if self.view_penalty == 'mask_flow_and_depth':
                # Depth test loss
                t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
                gt_mask = data.get('inputs.mask').byte().to(device)
                depth_pred = data.get('inputs.depth_pred').to(
                    device) * t_scale  # absolute depth from view
                background_setting(depth_pred, gt_mask)
                depth_z = out_pts[:, :, 2:].transpose(1, 2)
                corresponding_z = F.grid_sample(
                    depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                corresponding_z = corresponding_z.squeeze(1)

                loss_depth_test = F.relu(
                    depth_z - self.depth_test_eps - corresponding_z,
                    inplace=True).mean()
                loss = self.loss_depth_test_ratio * loss_depth_test
                eval_dict['loss_depth_test'] = loss.item()

    return eval_dict
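
# background_setting is imported from elsewhere; a hypothetical sketch of
# what it is assumed to do, renamed so it does not shadow the real helper:
# pixels outside the object mask get a large depth so that the depth test
# relu(depth_z - eps - corresponding_z) never penalises points that
# project onto the background.
def _sketch_background_setting(depth, mask, background_value=1e4):
    ''' In-place: sets depth to background_value where mask is zero. '''
    depth[~mask.bool()] = background_value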