def compute_sem_dis_loss(self, mesh, num_render, semantic_discriminator_net,
                         device):
    # need to make sure this matches render settings for discriminator training set
    # TODO: alternatively, randomize the angles each time?
    # 0.,  45.,  90., 135., 180., 225., 270., 315.
    azims = torch.linspace(0, 360, num_render + 1)[:-1]
    elevs = torch.Tensor([25 for i in range(num_render)])
    dists = torch.ones(num_render) * 1.7
    R, T = look_at_view_transform(dists, elevs, azims)

    meshes = mesh.extend(num_render)
    renders = utils.render_mesh(meshes,
                                R,
                                T,
                                device,
                                img_size=224,
                                silhouette=True)
    # convert the [num_render, 224, 224, 4] silhouette render (only the alpha
    # channel carries information) to a [num_render, 224, 224, 3] black/white RGB image
    renders_binary_rgb = torch.unsqueeze(renders[..., 3], 3).repeat(1, 1, 1, 3)

    loss = torch.sigmoid(
        semantic_discriminator_net(renders_binary_rgb.permute(0, 3, 1, 2)))
    loss = torch.mean(loss)

    return loss, renders_binary_rgb
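# `utils.render_mesh` is a project helper that is not shown in this excerpt.
# Below is a minimal sketch of an equivalent silhouette renderer in plain
# PyTorch3D (the settings here are assumptions; the real helper may differ),
# so the calls above can be read end to end.
import numpy as np
from pytorch3d.renderer import (BlendParams, FoVPerspectiveCameras,
                                MeshRasterizer, MeshRenderer,
                                RasterizationSettings, SoftSilhouetteShader)

def render_silhouettes(meshes, R, T, device, img_size=224):
    # one camera per mesh in the batch; R and T come from look_at_view_transform
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
    raster_settings = RasterizationSettings(
        image_size=img_size,
        blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
        faces_per_pixel=50)
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=SoftSilhouetteShader(blend_params=blend_params))
    # output is [N, img_size, img_size, 4]; the alpha channel holds the silhouette
    return renderer(meshes.to(device))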
Example #2
    def refine_mesh_batched(self, deform_net, semantic_dis_net, mesh_verts_batch, img_batch, pose_batch, compute_losses=True):

        # computing mesh deformation
        # NOTE: `mesh` is not an argument of this method; it is assumed to be the
        # template Meshes object (held by the enclosing class) that mesh_verts_batch
        # was taken from
        delta_v = deform_net(pose_batch, img_batch, mesh_verts_batch)
        delta_v = delta_v.reshape((-1, 3))
        deformed_mesh = mesh.offset_verts(delta_v)

        if not compute_losses:
            return deformed_mesh

        else:
            # prep inputs used to compute losses
            pred_dist = pose_batch[:,0]
            pred_elev = pose_batch[:,1]
            pred_azim = pose_batch[:,2]
            R, T = look_at_view_transform(pred_dist, pred_elev, pred_azim)
            # NOTE: `rgba_image` is also not defined in this excerpt; the ground-truth
            # silhouette mask is assumed to come from the alpha channel of the
            # corresponding RGBA input image
            mask = rgba_image[:, :, 3] > 0
            mask_gt = torch.tensor(mask, dtype=torch.float).to(self.device)
            num_vertices = mesh.verts_packed().shape[0]
            zero_deformation_tensor = torch.zeros((num_vertices, 3)).to(self.device)
            sym_plane_normal = [0,0,1] # TODO: make this generalizable to other classes

            loss_dict = {}
            # computing losses
            rendered_deformed_mesh = utils.render_mesh(deformed_mesh, R, T, self.device, img_size=224, silhouette=True)
            loss_dict["sil_loss"] = F.binary_cross_entropy(rendered_deformed_mesh[0, :,:, 3], mask_gt)
            loss_dict["l2_loss"] = F.mse_loss(delta_v, zero_deformation_tensor)
            loss_dict["lap_smoothness_loss"] = mesh_laplacian_smoothing(deformed_mesh)
            loss_dict["normal_consistency_loss"] = mesh_normal_consistency(deformed_mesh)

            # TODO: remove weights?
            if self.img_sym_lam > 0:
                loss_dict["img_sym_loss"], _ = def_losses.image_symmetry_loss(deformed_mesh, sym_plane_normal, self.cfg["training"]["img_sym_num_azim"], self.device)
            else:
                loss_dict["img_sym_loss"] = torch.tensor(0).to(self.device)
            if self.vertex_sym_lam > 0:
                loss_dict["vertex_sym_loss"] = def_losses.vertex_symmetry_loss_fast(deformed_mesh, sym_plane_normal, self.device)
            else:
                loss_dict["vertex_sym_loss"] = torch.tensor(0).to(self.device)
            if self.semantic_dis_lam > 0:
                loss_dict["semantic_dis_loss"], _ = compute_sem_dis_loss(deformed_mesh, self.semantic_dis_loss_num_render, semantic_dis_net, self.device)
            else:
                loss_dict["semantic_dis_loss"] = torch.tensor(0).to(self.device)

            return loss_dict, deformed_mesh
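# Hedged usage sketch: how the loss_dict returned above might be combined into a
# single training objective, mirroring the weighting used in refine_mesh further
# below. `refiner` and `optimizer` are assumptions here; the *_lam attribute
# names are taken from refine_mesh.
loss_dict, deformed_mesh = refiner.refine_mesh_batched(
    deform_net, semantic_dis_net, mesh_verts_batch, img_batch, pose_batch)
total_loss = (loss_dict["sil_loss"] * refiner.sil_lam
              + loss_dict["l2_loss"] * refiner.l2_lam
              + loss_dict["lap_smoothness_loss"] * refiner.lap_lam
              + loss_dict["normal_consistency_loss"] * refiner.normals_lam
              + loss_dict["img_sym_loss"] * refiner.img_sym_lam
              + loss_dict["vertex_sym_loss"] * refiner.vertex_sym_lam
              + loss_dict["semantic_dis_loss"] * refiner.semantic_dis_lam)
optimizer.zero_grad()
total_loss.backward()
optimizer.step()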
def image_symmetry_loss(mesh,
                        sym_plane,
                        num_azim,
                        device,
                        render_silhouettes=True):
    N = np.array([sym_plane])
    if not np.isclose(np.linalg.norm(N), 1.0):
        raise ValueError("sym_plane needs to be a unit normal")

    # camera positions for one half of the sphere. Offset allows for better angle even when num_azim = 1
    num_views_on_half = num_azim * 2
    offset = 15
    #azims = torch.linspace(0,90,num_azim+2)[1:-1].repeat(2)
    azims = torch.linspace(0 + offset, 90 - offset, num_azim).repeat(2)
    elevs = torch.Tensor([-45 for i in range(num_azim)] +
                         [45 for i in range(num_azim)])
    dists = torch.ones(num_views_on_half) * 1.9
    R_half_1, T_half_1 = look_at_view_transform(dists, elevs, azims)
    R = [R_half_1]

    # compute the other half of camera positions according to plane of symmetry
    reflect_matrix = torch.tensor(np.eye(3) - 2 * N.T @ N, dtype=torch.float)
    for i in range(num_views_on_half):
        camera_position = pytorch3d.renderer.cameras.camera_position_from_spherical_angles(
            dists[i], elevs[i], azims[i])
        R_sym = pytorch3d.renderer.cameras.look_at_rotation(
            camera_position @ reflect_matrix)
        R.append(R_sym)
    R = torch.cat(R)
    T = torch.cat([T_half_1, T_half_1])

    # rendering
    meshes = mesh.extend(num_views_on_half * 2)
    if render_silhouettes:
        renders = utils.render_mesh(meshes,
                                    R,
                                    T,
                                    device,
                                    img_size=224,
                                    silhouette=True)[..., 3]
    else:
        renders = utils.render_mesh(meshes,
                                    R,
                                    T,
                                    device,
                                    img_size=224,
                                    silhouette=False)

    # a sym_triple is [R1, R1_flipped, R2]
    sym_triples = []
    for i in range(num_views_on_half):
        sym_triples.append([
            renders[i],
            torch.flip(renders[i], [1]), renders[i + num_views_on_half]
        ])

    # calculating loss.
    sym_loss = 0
    for sym_triple in sym_triples:
        sym_loss += F.mse_loss(sym_triple[1], sym_triple[2])
    sym_loss = sym_loss / num_views_on_half

    return sym_loss, sym_triples
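# Small sanity check of the reflection used above: with sym_plane = [0, 0, 1],
# reflect_matrix = I - 2 * N^T N simply negates the z-coordinate, so the second
# set of cameras views the mesh from the mirrored side of the symmetry plane.
import numpy as np
import torch

N = np.array([[0.0, 0.0, 1.0]])
reflect_matrix = torch.tensor(np.eye(3) - 2 * N.T @ N, dtype=torch.float)
camera_position = torch.tensor([[1.0, 0.5, 0.8]])   # a camera on the +z side
print(camera_position @ reflect_matrix)             # tensor([[1.0000, 0.5000, -0.8000]])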
Example #4
    def refine_mesh(self, mesh, rgba_image, pred_dist, pred_elev, pred_azim, record_intermediate=False):
        '''
        Args:
            mesh (pytorch3d Meshes): source mesh to deform
            rgba_image (np uint8 array, 224 x 224 x 4): RGBA image, values 0-255
            pred_dist (float): predicted camera distance
            pred_elev (float): predicted camera elevation
            pred_azim (float): predicted camera azimuth
            record_intermediate (bool): if True, also keep the deformed mesh every 100 iterations
        '''

        # prep inputs used during training
        image = rgba_image[:,:,:3]
        image_in = torch.unsqueeze(torch.tensor(image/255, dtype=torch.float).permute(2,0,1),0).to(self.device)
        mask = rgba_image[:,:,3] > 0
        mask_gt = torch.tensor(mask, dtype=torch.float).to(self.device)
        pose_in = torch.unsqueeze(torch.tensor([pred_dist, pred_elev, pred_azim]),0).to(self.device)
        verts_in = torch.unsqueeze(mesh.verts_packed(),0).to(self.device)

        R, T = look_at_view_transform(pred_dist, pred_elev, pred_azim) 
        num_vertices = mesh.verts_packed().shape[0]
        zero_deformation_tensor = torch.zeros((num_vertices, 3)).to(self.device)

        # prep network & optimizer
        deform_net = DeformationNetwork(self.cfg, num_vertices, self.device)
        deform_net.to(self.device)
        optimizer = optim.Adam(deform_net.parameters(), lr=self.cfg["training"]["learning_rate"])

        # optimizing  
        loss_info = pd.DataFrame()
        deformed_meshes = []
        for i in tqdm(range(self.num_iterations)):
            deform_net.train()
            optimizer.zero_grad()
            
            # computing mesh deformation & its render at the input pose
            delta_v = deform_net(pose_in, image_in, verts_in)
            delta_v = delta_v.reshape((-1,3))
            deformed_mesh = mesh.offset_verts(delta_v)
            rendered_deformed_mesh = utils.render_mesh(deformed_mesh, R, T, self.device, img_size=224, silhouette=True)

            # computing losses
            l2_loss = F.mse_loss(delta_v, zero_deformation_tensor)
            lap_smoothness_loss = mesh_laplacian_smoothing(deformed_mesh)
            normal_consistency_loss = mesh_normal_consistency(deformed_mesh)

            sil_loss = F.binary_cross_entropy(rendered_deformed_mesh[0, :,:, 3], mask_gt)

            sym_plane_normal = [0,0,1]
            if self.img_sym_lam > 0:
                img_sym_loss, _ = def_losses.image_symmetry_loss(deformed_mesh, sym_plane_normal, self.img_sym_num_azim, self.device)
            else:
                img_sym_loss = torch.tensor(0).to(self.device)
            if self.vertex_sym_lam > 0:
                vertex_sym_loss = def_losses.vertex_symmetry_loss_fast(deformed_mesh, sym_plane_normal, self.device)
            else:
                vertex_sym_loss = torch.tensor(0).to(self.device)
            if self.semantic_dis_lam > 0:
                semantic_dis_loss, _ = self.semantic_loss_computer.compute_loss(deformed_mesh)
            else:
                semantic_dis_loss = torch.tensor(0).to(self.device)

            # optimization step on weighted losses
            total_loss = (sil_loss*self.sil_lam + l2_loss*self.l2_lam + lap_smoothness_loss*self.lap_lam +
                          normal_consistency_loss*self.normals_lam + img_sym_loss*self.img_sym_lam + 
                          vertex_sym_loss*self.vertex_sym_lam + semantic_dis_loss*self.semantic_dis_lam)
            total_loss.backward()
            optimizer.step()

            # saving info
            iter_loss_info = {"iter":i, "sil_loss": sil_loss.item(), "l2_loss": l2_loss.item(), 
                              "lap_smoothness_loss":lap_smoothness_loss.item(),
                              "normal_consistency_loss": normal_consistency_loss.item(),"img_sym_loss": img_sym_loss.item(),
                              "vertex_sym_loss": vertex_sym_loss.item(), "semantic_dis_loss": semantic_dis_loss.item(),
                              "total_loss": total_loss.item()}
            # DataFrame.append was removed in pandas 2.0; use pd.concat instead
            loss_info = pd.concat([loss_info, pd.DataFrame([iter_loss_info])], ignore_index=True)
            if record_intermediate and (i % 100 == 0 or i == self.num_iterations-1):
                print(i)
                deformed_meshes.append(deformed_mesh)

        if record_intermediate:
            return deformed_meshes, loss_info
        else:
            return deformed_mesh, loss_info
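# Hedged usage sketch for refine_mesh. The enclosing class name (`MeshRefiner`),
# its constructor arguments, the image path, and the pose values are all
# illustrative assumptions; only the refine_mesh signature comes from the
# snippet above.
import numpy as np
import torch
from PIL import Image
from pytorch3d.utils import ico_sphere

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
refiner = MeshRefiner(cfg, device)                               # assumed constructor
rgba_image = np.array(Image.open("chair.png").convert("RGBA"))   # 224 x 224 x 4, uint8
src_mesh = ico_sphere(4, device)                                 # template mesh to deform
refined_mesh, loss_info = refiner.refine_mesh(
    src_mesh, rgba_image, pred_dist=1.7, pred_elev=25.0, pred_azim=90.0)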