def compute_sem_dis_loss(self, mesh, num_render, semantic_discriminator_net, device):
    # need to make sure these angles match the render settings used for the discriminator's training set
    # TODO: alternatively, randomize the angles each time?
    # 0., 45., 90., 135., 180., 225., 270., 315.
    azims = torch.linspace(0, 360, num_render + 1)[:-1]
    elevs = torch.Tensor([25 for i in range(num_render)])
    dists = torch.ones(num_render) * 1.7
    R, T = look_at_view_transform(dists, elevs, azims)

    meshes = mesh.extend(num_render)
    renders = utils.render_mesh(meshes, R, T, device, img_size=224, silhouette=True)

    # converting from a [num_render, 224, 224, 4] silhouette render (only the alpha channel, index 3, has info)
    # to a [num_render, 224, 224, 3] black/white rgb image
    renders_binary_rgb = torch.unsqueeze(renders[..., 3], 3).repeat(1, 1, 1, 3)

    loss = torch.sigmoid(semantic_discriminator_net(renders_binary_rgb.permute(0, 3, 1, 2)))
    loss = torch.mean(loss)

    return loss, renders_binary_rgb
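# Illustrative usage (a sketch, not part of the pipeline): `refiner`, `chair_mesh`, and
# `dis_net` are hypothetical names for an instance of this class, a pytorch3d Meshes
# object, and a trained silhouette discriminator.
#
#   loss, sil_renders = refiner.compute_sem_dis_loss(chair_mesh, 8, dis_net, refiner.device)
#   # loss: scalar averaged over the 8 views; sil_renders: [8, 224, 224, 3] black/white renders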
def refine_mesh_batched(self, deform_net, semantic_dis_net, mesh_verts_batch, img_batch, pose_batch, compute_losses=True):
    # FIXME: `mesh` and `rgba_image` are referenced below but are not arguments to this
    # method; they need to be passed in or stored on the instance.

    # computing mesh deformation
    delta_v = deform_net(pose_batch, img_batch, mesh_verts_batch)
    delta_v = delta_v.reshape((-1, 3))
    deformed_mesh = mesh.offset_verts(delta_v)

    if not compute_losses:
        return deformed_mesh

    # prep inputs used to compute losses
    pred_dist = pose_batch[:, 0]
    pred_elev = pose_batch[:, 1]
    pred_azim = pose_batch[:, 2]
    R, T = look_at_view_transform(pred_dist, pred_elev, pred_azim)
    mask = rgba_image[:, :, 3] > 0
    mask_gt = torch.tensor(mask, dtype=torch.float).to(self.device)
    num_vertices = mesh.verts_packed().shape[0]
    zero_deformation_tensor = torch.zeros((num_vertices, 3)).to(self.device)
    sym_plane_normal = [0, 0, 1]  # TODO: make this generalizable to other classes

    loss_dict = {}

    # computing losses
    rendered_deformed_mesh = utils.render_mesh(deformed_mesh, R, T, self.device, img_size=224, silhouette=True)
    loss_dict["sil_loss"] = F.binary_cross_entropy(rendered_deformed_mesh[0, :, :, 3], mask_gt)
    loss_dict["l2_loss"] = F.mse_loss(delta_v, zero_deformation_tensor)
    loss_dict["lap_smoothness_loss"] = mesh_laplacian_smoothing(deformed_mesh)
    loss_dict["normal_consistency_loss"] = mesh_normal_consistency(deformed_mesh)

    # TODO: remove weights?
    if self.img_sym_lam > 0:
        loss_dict["img_sym_loss"], _ = def_losses.image_symmetry_loss(
            deformed_mesh, sym_plane_normal, self.cfg["training"]["img_sym_num_azim"], self.device)
    else:
        loss_dict["img_sym_loss"] = torch.tensor(0).to(self.device)

    if self.vertex_sym_lam > 0:
        loss_dict["vertex_sym_loss"] = def_losses.vertex_symmetry_loss_fast(deformed_mesh, sym_plane_normal, self.device)
    else:
        loss_dict["vertex_sym_loss"] = torch.tensor(0).to(self.device)

    if self.semantic_dis_lam > 0:
        loss_dict["semantic_dis_loss"], _ = self.compute_sem_dis_loss(
            deformed_mesh, self.semantic_dis_loss_num_render, semantic_dis_net, self.device)
    else:
        loss_dict["semantic_dis_loss"] = torch.tensor(0).to(self.device)

    return loss_dict, deformed_mesh
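# Illustrative sketch of turning the returned loss_dict into a single weighted loss,
# mirroring the weighting used in refine_mesh below; `lams` is a hypothetical dict built
# from the per-loss weights stored on this class (self.sil_lam, self.l2_lam, ...).
#
#   lams = {"sil_loss": self.sil_lam, "l2_loss": self.l2_lam,
#           "lap_smoothness_loss": self.lap_lam, "normal_consistency_loss": self.normals_lam,
#           "img_sym_loss": self.img_sym_lam, "vertex_sym_loss": self.vertex_sym_lam,
#           "semantic_dis_loss": self.semantic_dis_lam}
#   total_loss = sum(loss_dict[name] * lam for name, lam in lams.items())
#   total_loss.backward()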
def image_symmetry_loss(mesh, sym_plane, num_azim, device, render_silhouettes=True):
    N = np.array([sym_plane])
    if not np.isclose(np.linalg.norm(N), 1):
        raise ValueError("sym_plane needs to be a unit normal")

    # camera positions for one half of the sphere. The offset allows for a better angle even when num_azim = 1
    num_views_on_half = num_azim * 2
    offset = 15
    # azims = torch.linspace(0, 90, num_azim + 2)[1:-1].repeat(2)
    azims = torch.linspace(0 + offset, 90 - offset, num_azim).repeat(2)
    elevs = torch.Tensor([-45 for i in range(num_azim)] + [45 for i in range(num_azim)])
    dists = torch.ones(num_views_on_half) * 1.9
    R_half_1, T_half_1 = look_at_view_transform(dists, elevs, azims)
    R = [R_half_1]

    # compute the other half of the camera positions by reflecting each camera across the plane of symmetry
    reflect_matrix = torch.tensor(np.eye(3) - 2 * N.T @ N, dtype=torch.float)
    for i in range(num_views_on_half):
        camera_position = pytorch3d.renderer.cameras.camera_position_from_spherical_angles(
            dists[i], elevs[i], azims[i])
        R_sym = pytorch3d.renderer.cameras.look_at_rotation(camera_position @ reflect_matrix)
        R.append(R_sym)
    R = torch.cat(R)
    T = torch.cat([T_half_1, T_half_1])

    # rendering
    meshes = mesh.extend(num_views_on_half * 2)
    if render_silhouettes:
        renders = utils.render_mesh(meshes, R, T, device, img_size=224, silhouette=True)[..., 3]
    else:
        renders = utils.render_mesh(meshes, R, T, device, img_size=224, silhouette=False)

    # a sym_triple is [render, render_flipped, render_from_mirrored_camera]
    sym_triples = []
    for i in range(num_views_on_half):
        sym_triples.append([renders[i], torch.flip(renders[i], [1]), renders[i + num_views_on_half]])

    # calculating loss
    sym_loss = 0
    for sym_triple in sym_triples:
        sym_loss += F.mse_loss(sym_triple[1], sym_triple[2])
    sym_loss = sym_loss / num_views_on_half

    return sym_loss, sym_triples
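# A minimal sketch (assuming only numpy) of why `np.eye(3) - 2 * N.T @ N` above mirrors
# camera positions across the symmetry plane: it is the Householder reflection for a
# plane through the origin with unit normal N. The function name below is hypothetical
# and not used elsewhere in the pipeline.
def _reflection_matrix_example():
    n = np.array([[0.0, 0.0, 1.0]])       # unit normal of the z = 0 symmetry plane
    reflect = np.eye(3) - 2 * n.T @ n      # Householder reflection about that plane
    cam = np.array([1.0, 0.5, 0.3])        # an arbitrary camera position
    mirrored = cam @ reflect               # z component is negated, x and y unchanged
    assert np.allclose(mirrored, [1.0, 0.5, -0.3])
    return mirrored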
def refine_mesh(self, mesh, rgba_image, pred_dist, pred_elev, pred_azim, record_intermediate=False):
    '''
    Args:
        mesh (pytorch3d Meshes): initial mesh to deform
        rgba_image (np int array, 224 x 224 x 4, rgba, 0-255)
        pred_dist (int)
        pred_elev (int)
        pred_azim (int)
        record_intermediate (bool): if True, also return intermediate meshes every 100 iterations
    '''
    # prep inputs used during training
    image = rgba_image[:, :, :3]
    image_in = torch.unsqueeze(torch.tensor(image / 255, dtype=torch.float).permute(2, 0, 1), 0).to(self.device)
    mask = rgba_image[:, :, 3] > 0
    mask_gt = torch.tensor(mask, dtype=torch.float).to(self.device)
    pose_in = torch.unsqueeze(torch.tensor([pred_dist, pred_elev, pred_azim]), 0).to(self.device)
    verts_in = torch.unsqueeze(mesh.verts_packed(), 0).to(self.device)
    R, T = look_at_view_transform(pred_dist, pred_elev, pred_azim)
    num_vertices = mesh.verts_packed().shape[0]
    zero_deformation_tensor = torch.zeros((num_vertices, 3)).to(self.device)

    # prep network & optimizer
    deform_net = DeformationNetwork(self.cfg, num_vertices, self.device)
    deform_net.to(self.device)
    optimizer = optim.Adam(deform_net.parameters(), lr=self.cfg["training"]["learning_rate"])

    # optimizing
    loss_info = pd.DataFrame()
    deformed_meshes = []
    for i in tqdm(range(self.num_iterations)):
        deform_net.train()
        optimizer.zero_grad()

        # computing mesh deformation & its render at the input pose
        delta_v = deform_net(pose_in, image_in, verts_in)
        delta_v = delta_v.reshape((-1, 3))
        deformed_mesh = mesh.offset_verts(delta_v)
        rendered_deformed_mesh = utils.render_mesh(deformed_mesh, R, T, self.device, img_size=224, silhouette=True)

        # computing losses
        l2_loss = F.mse_loss(delta_v, zero_deformation_tensor)
        lap_smoothness_loss = mesh_laplacian_smoothing(deformed_mesh)
        normal_consistency_loss = mesh_normal_consistency(deformed_mesh)
        sil_loss = F.binary_cross_entropy(rendered_deformed_mesh[0, :, :, 3], mask_gt)

        sym_plane_normal = [0, 0, 1]
        if self.img_sym_lam > 0:
            img_sym_loss, _ = def_losses.image_symmetry_loss(deformed_mesh, sym_plane_normal, self.img_sym_num_azim, self.device)
        else:
            img_sym_loss = torch.tensor(0).to(self.device)
        if self.vertex_sym_lam > 0:
            vertex_sym_loss = def_losses.vertex_symmetry_loss_fast(deformed_mesh, sym_plane_normal, self.device)
        else:
            vertex_sym_loss = torch.tensor(0).to(self.device)
        if self.semantic_dis_lam > 0:
            semantic_dis_loss, _ = self.semantic_loss_computer.compute_loss(deformed_mesh)
        else:
            semantic_dis_loss = torch.tensor(0).to(self.device)

        # optimization step on weighted losses
        total_loss = (sil_loss * self.sil_lam + l2_loss * self.l2_lam +
                      lap_smoothness_loss * self.lap_lam + normal_consistency_loss * self.normals_lam +
                      img_sym_loss * self.img_sym_lam + vertex_sym_loss * self.vertex_sym_lam +
                      semantic_dis_loss * self.semantic_dis_lam)
        total_loss.backward()
        optimizer.step()

        # saving info
        iter_loss_info = {"iter": i, "sil_loss": sil_loss.item(), "l2_loss": l2_loss.item(),
                          "lap_smoothness_loss": lap_smoothness_loss.item(),
                          "normal_consistency_loss": normal_consistency_loss.item(),
                          "img_sym_loss": img_sym_loss.item(), "vertex_sym_loss": vertex_sym_loss.item(),
                          "semantic_dis_loss": semantic_dis_loss.item(), "total_loss": total_loss.item()}
        # DataFrame.append was removed in pandas 2.x; concat a one-row frame instead
        loss_info = pd.concat([loss_info, pd.DataFrame([iter_loss_info])], ignore_index=True)

        if record_intermediate and (i % 100 == 0 or i == self.num_iterations - 1):
            deformed_meshes.append(deformed_mesh)

    if record_intermediate:
        return deformed_meshes, loss_info
    else:
        return deformed_mesh, loss_info
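# Illustrative usage (a sketch): `refiner` and `src_mesh` are hypothetical names for an
# instance of this class and an initial pytorch3d Meshes object, and `rgba` is a
# 224 x 224 x 4 uint8 RGBA view of the object; the pose values echo those used above.
#
#   refined_mesh, loss_df = refiner.refine_mesh(src_mesh, rgba,
#                                               pred_dist=1.7, pred_elev=25, pred_azim=90)
#   loss_df.to_csv("refine_losses.csv", index=False)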