def __pool_main(self, mesh_index):
        """Pool the mesh at ``mesh_index`` down to ``self.__out_target`` edges.

        Edges are popped from a priority queue and handed to
        ``self.__pool_edge`` until the mesh's edge count no longer exceeds
        the target. The mesh is then cleaned, and the features rebuilt by
        ``edge_groups.rebuild_features`` are stored in
        ``self.__updated_fe[mesh_index]``.

        Priorities come from ``self.__edge_priorities`` when it is set.
        Otherwise each edge's priority is the sum of squares of its feature
        column in ``self.__fe``, i.e. its squared L2 norm over axis 1.

        Args:
            mesh_index: Index of the mesh in ``self.__meshes`` and of its
                slice along the first axis of ``self.__fe``.

        Raises:
            IndexError: If the priority queue runs out before the target
                edge count is reached.
        """
        mesh = self.__meshes[mesh_index]
        edges_count = mesh.edges_count
        if self.__edge_priorities is None:
            # Assumes __fe is laid out as (mesh, channel, edge) -- TODO confirm.
            features = self.__fe[mesh_index, :, :edges_count]
            edge_priorities = torch.sum(features * features, 0)
        else:
            edge_priorities = self.__edge_priorities[mesh_index, :edges_count]
        queue = self.__build_priority_queue(edge_priorities, edges_count)

        # Side effect kept from the original: pass the priorities to the
        # mesh's export routine.
        mesh.export(edge_priorities=edge_priorities)

        # np.bool was removed in NumPy 1.24; the builtin bool yields the same
        # boolean dtype on every NumPy version.
        mask = np.ones(edges_count, dtype=bool)
        edge_groups = MeshUnion(edges_count, self.__fe.device)
        # NOTE(review): termination relies on __pool_edge lowering
        # mesh.edges_count -- confirm.
        while mesh.edges_count > self.__out_target:
            if not queue:
                raise IndexError(
                    f'pooling queue exhausted with {mesh.edges_count} edges '
                    f'left (target {self.__out_target})')
            _, edge_id = heappop(queue)
            edge_id = int(edge_id)
            # Only pool edges whose mask entry is still set (presumably
            # cleared by __pool_edge for edges it removes).
            if mask[edge_id]:
                self.__pool_edge(mesh, edge_id, mask, edge_groups)
        mesh.clean(mask, edge_groups)
        fe = edge_groups.rebuild_features(self.__fe[mesh_index], mask,
                                          self.__out_target)
        self.__updated_fe[mesh_index] = fe
示例#2
0
    def __pool_main(self, mesh_index):
        """Pool the mesh at ``mesh_index`` down to ``self.__out_target`` edges.

        Edges are popped from the queue built by ``self.__build_queue`` and
        handed to ``self.__pool_edge`` until the mesh's edge count no longer
        exceeds the target. ``self.to_merge_edges`` is reset to an empty list
        before the loop. After the loop it is passed, together with the edge
        groups, to ``MeshPool.__union_multiple_groups``. The mesh is then
        cleaned, and the features rebuilt by ``edge_groups.rebuild_features``
        are stored in ``self.__updated_fe[mesh_index]``.

        Args:
            mesh_index: Index of the mesh in ``self.__meshes`` and of its
                slice along the first axis of ``self.__fe``.

        Raises:
            IndexError: If the queue runs out before the target edge count
                is reached. The message includes the current edge count and
                the target.
        """
        mesh = self.__meshes[mesh_index]
        queue = self.__build_queue(self.__fe[mesh_index, :, :mesh.edges_count],
                                   mesh.edges_count)
        # np.bool was removed in NumPy 1.24; the builtin bool yields the same
        # boolean dtype on every NumPy version.
        mask = np.ones(mesh.edges_count, dtype=bool)
        edge_groups = MeshUnion(mesh.edges_count, self.__fe.device)

        # Presumably filled by __pool_edge with edges whose groups must be
        # merged once pooling finishes -- confirm against __pool_edge.
        self.to_merge_edges = []
        while mesh.edges_count > self.__out_target:
            # Explicit check instead of a bare ``except:``. The bare form also
            # caught KeyboardInterrupt/SystemExit and only printed these
            # values to stdout.
            if not queue:
                raise IndexError(
                    f'pooling queue exhausted with {mesh.edges_count} edges '
                    f'left (target {self.__out_target})')
            _, edge_id = heappop(queue)
            edge_id = int(edge_id)
            # Only pool edges whose mask entry is still set.
            if mask[edge_id]:
                self.__pool_edge(mesh, edge_id, mask, edge_groups)

        MeshPool.__union_multiple_groups(mesh, edge_groups,
                                         self.to_merge_edges)
        mesh.clean(mask, edge_groups)
        fe = edge_groups.rebuild_features(self.__fe[mesh_index], mask,
                                          self.__out_target)
        self.__updated_fe[mesh_index] = fe
示例#3
0
 def __pool_main(self, mesh_index):
     """Pool the mesh at ``mesh_index`` down to ``self.__out_target`` edges.

     Edges are popped from the queue built by ``self.__build_queue`` and
     handed to ``self.__pool_edge`` until the mesh's edge count no longer
     exceeds the target. The mesh is then cleaned, and the features rebuilt
     by ``edge_groups.rebuild_features`` are stored in
     ``self.__updated_fe[mesh_index]``.

     Args:
         mesh_index: Index of the mesh in ``self.__meshes`` and of its slice
             along the first axis of ``self.__fe``.

     Raises:
         IndexError: If the queue runs out before the target edge count is
             reached.
     """
     mesh = self.__meshes[mesh_index]
     queue = self.__build_queue(self.__fe[mesh_index, :, :mesh.edges_count],
                                mesh.edges_count)
     # np.bool was removed in NumPy 1.24; the builtin bool yields the same
     # boolean dtype on every NumPy version.
     mask = np.ones(mesh.edges_count, dtype=bool)
     edge_groups = MeshUnion(mesh.edges_count, self.__fe.device)
     while mesh.edges_count > self.__out_target:
         if not queue:
             raise IndexError(
                 f'pooling queue exhausted with {mesh.edges_count} edges '
                 f'left (target {self.__out_target})')
         _, edge_id = heappop(queue)
         edge_id = int(edge_id)
         # Only pool edges whose mask entry is still set.
         if mask[edge_id]:
             self.__pool_edge(mesh, edge_id, mask, edge_groups)
     mesh.clean(mask, edge_groups)
     fe = edge_groups.rebuild_features(self.__fe[mesh_index], mask,
                                       self.__out_target)
     self.__updated_fe[mesh_index] = fe
示例#4
0
 def __pool_main(self, mesh_index):
     """Pool the mesh at ``mesh_index`` down to ``self.__out_target`` edges.

     Edges are visited in ascending order of the squared L2 norm of their
     feature column in ``self.__fe`` (sum of squares over axis 1). Each
     visited edge is handed to ``self.__pool_edge`` until the mesh's edge
     count no longer exceeds the target. The mesh is then cleaned, and the
     features rebuilt by ``edge_groups.rebuild_features`` are stored in
     ``self.__updated_fe[mesh_index]``.

     Args:
         mesh_index: Index of the mesh in ``self.__meshes`` and of its slice
             along the first axis of ``self.__fe``.

     Raises:
         IndexError: If every edge has been visited before the target edge
             count is reached.
     """
     mesh = self.__meshes[mesh_index]
     edge_fe = self.__fe[mesh_index, :, :mesh.edges_count]
     in_fe_sq = torch.sum(edge_fe ** 2, dim=0)
     # Sorting descending and popping from the end of the list visits the
     # lowest-norm edges first. The sorted values themselves are unused.
     _, edge_ids = torch.sort(in_fe_sq, descending=True)
     edge_ids = edge_ids.tolist()
     # np.bool was removed in NumPy 1.24; the builtin bool yields the same
     # boolean dtype on every NumPy version.
     mask = np.ones(mesh.edges_count, dtype=bool)
     edge_groups = MeshUnion(mesh.edges_count, self.__fe.device)
     while mesh.edges_count > self.__out_target:
         if not edge_ids:
             raise IndexError(
                 f'no edges left to pool with {mesh.edges_count} edges '
                 f'remaining (target {self.__out_target})')
         edge_id = edge_ids.pop()
         # Only pool edges whose mask entry is still set.
         if mask[edge_id]:
             self.__pool_edge(mesh, edge_id, mask, edge_groups)
     mesh.clean(mask, edge_groups)
     fe = edge_groups.rebuild_features(self.__fe[mesh_index], mask,
                                       self.__out_target)
     self.__updated_fe[mesh_index] = fe
示例#5
0
 def __pool_main(self, mesh_index):
     """Pool the mesh at ``mesh_index`` down to ``self.__out_target`` faces.

     Faces are popped from the queue built by ``self.__build_queue`` and
     handed to ``self.__pool_face`` until the mesh's face count no longer
     exceeds the target. The mesh is then cleaned, and the features rebuilt
     by ``face_groups.rebuild_features`` are stored in
     ``self.__updated_fe[mesh_index]``.

     Args:
         mesh_index: Index of the mesh in ``self.__meshes`` and of its slice
             along the first axis of ``self.__fe``.

     Raises:
         IndexError: If the queue runs out before the target face count is
             reached.
     """
     mesh = self.__meshes[mesh_index]
     queue = self.__build_queue(self.__fe[mesh_index, :, :mesh.faces_count],
                                mesh.faces_count)
     # np.bool was removed in NumPy 1.24; the builtin bool yields the same
     # boolean dtype on every NumPy version.
     mask = np.ones(mesh.faces_count, dtype=bool)
     face_groups = MeshUnion(mesh.faces_count, self.__fe.device)
     while mesh.faces_count > self.__out_target:
         if not queue:
             raise IndexError(
                 f'pooling queue exhausted with {mesh.faces_count} faces '
                 f'left (target {self.__out_target})')
         _, face_id = heappop(queue)
         face_id = int(face_id)
         # Only pool faces whose mask entry is still set.
         if mask[face_id]:
             self.__pool_face(mesh, face_id, mask, face_groups)
     mesh.clean(mask, face_groups)
     fe = face_groups.rebuild_features(self.__fe[mesh_index], mask,
                                       self.__out_target)
     self.__updated_fe[mesh_index] = fe
示例#6
0
 def __unpool_main(self, mesh_index, unroll_target):
     """Unpool the mesh at ``mesh_index`` by splitting faces.

     Faces are popped from the queue built by ``self.__build_queue``. Each
     face that passes ``self.check_valid`` and has not been split yet is
     handed to ``self.__unpool_face`` while ``mesh.faces_count + 6`` is below
     ``unroll_target``. Afterwards the steps are:

     - ``mesh.pool_count`` is decremented.
     - The face features rebuilt by ``face_groups.rebuild_features`` are
       stored in ``self.__updated_fe[mesh_index]``.
     - The vertex data rebuilt by ``vs_groups.rebuild_vs_average`` (after
       zero-padding the groups up to ``self.vs_target`` columns) is stored
       in ``self.__updated_vs[mesh_index]``.

     Args:
         mesh_index: Index of the mesh in ``self.__meshes`` and of its slices
             in ``self.__fe`` / ``self.__vs``.
         unroll_target: Face count that bounds the splitting loop.

     Raises:
         IndexError: If the queue runs out before the loop condition is met.
     """
     mesh = self.__meshes[mesh_index]
     queue = self.__build_queue(self.__fe[mesh_index, :, :mesh.faces_count],
                                mesh.faces_count, mesh)
     face_groups = MeshUnion(mesh.faces_count, self.__fe.device)
     vs_groups = MeshUnion(self.v_count[mesh_index], self.__fe.device)
     # np.bool was removed in NumPy 1.24; the builtin bool yields the same
     # boolean dtype on every NumPy version.
     not_split = np.ones(mesh.faces_count, dtype=bool)
     # NOTE(review): the +6 margin presumably leaves room for the faces one
     # split adds, so the count does not overshoot -- confirm.
     while mesh.faces_count + 6 < unroll_target:
         if not queue:
             raise IndexError(
                 f'unpooling queue exhausted with {mesh.faces_count} faces '
                 f'(target {unroll_target})')
         _, face_id = heappop(queue)
         face_id = int(face_id)
         # check_valid is evaluated before the not_split lookup, as before.
         if self.check_valid(mesh, face_id) and not_split[face_id]:
             not_split = self.__unpool_face(mesh, mesh_index, face_id,
                                            face_groups, vs_groups,
                                            not_split)
     mesh.pool_count -= 1
     mask = np.ones(mesh.faces_count, dtype=bool)
     # NOTE(review): this uses self.unroll_target rather than the
     # unroll_target argument that bounds the loop above -- confirm intended.
     fe = face_groups.rebuild_features(self.__fe[mesh_index], mask,
                                       self.unroll_target)
     vs_padding = self.vs_target - vs_groups.groups.shape[1]
     if vs_padding > 0:
         # Zero-pad the last dimension of the groups on the right up to
         # vs_target.
         pad = ConstantPad2d((0, vs_padding), 0)
         vs_groups.groups = pad(vs_groups.groups)
     vs = vs_groups.rebuild_vs_average(self.__vs[mesh_index],
                                       self.vs_target)
     self.__updated_fe[mesh_index] = fe
     self.__updated_vs[mesh_index] = vs
示例#7
0
 def init_history(self):
     """Reset ``self.history_data`` to a fresh history record.

     The record contains:

     - empty ``groups`` and ``occurrences`` lists;
     - a one-element list holding a copy of ``self.gemm_edges``;
     - identity ``old2current`` / ``current2old`` int32 index arrays;
     - a one-element list holding an all-ones uint8 edge mask;
     - a one-element list holding the current edge count.

     When ``self.export_folder`` is truthy, a ``MeshUnion`` over the current
     edges is also stored under ``'collapses'``.
     """
     n_edges = self.edges_count
     record = {}
     record['groups'] = []
     record['gemm_edges'] = [self.gemm_edges.copy()]
     record['occurrences'] = []
     # Two independent identity arrays (not one shared buffer), so changing
     # one map later cannot alter the other.
     record['old2current'] = np.arange(n_edges, dtype=np.int32)
     record['current2old'] = np.arange(n_edges, dtype=np.int32)
     record['edges_mask'] = [torch.ones(n_edges, dtype=torch.uint8)]
     record['edges_count'] = [n_edges]
     self.history_data = record
     if not self.export_folder:
         return
     record['collapses'] = MeshUnion(n_edges)