Beispiel #1
0
	def test_mesh_obj_material(self, config):
		"""Export each test view as an OBJ mesh with one material per convex.

		Loads the checkpoint via ``self.load`` (prints a message and returns
		early on failure). Writes ``<sample_dir>/default.mtl`` with one grey
		material ``m<i+1>`` per convex slot ``i`` in ``range(self.c_dim)``, then
		for every ``t`` in ``[config.start, min(len(self.data_pixels), config.end))``:

		* encodes view ``self.data_pixels[t, self.test_idx]`` (divided by 255)
		  into plane parameters ``out_m``/``out_b``;
		* evaluates the per-convex output ``self.zG2`` chunk by chunk over
		  ``self.coords`` into a ``real_size**3 x c_dim`` grid;
		* keeps every convex that contains at least one voxel and is not
		  redundant (every voxel inside it is also inside at least one other
		  convex still counted in the running coverage);
		* writes ``<sample_dir>/<t>_bsp.obj`` where convex slot ``i`` uses
		  material ``m<i+1>``, so it gets the same colour in every shape.

		Args:
			config: object providing ``sample_dir``, ``start`` and ``end``.
		"""
		could_load, _ = self.load(self.checkpoint_dir)
		if could_load:
			print(" [*] Load SUCCESS")
		else:
			print(" [!] Load failed...")
			return

		#w2[j,i]>0.01 means plane j is used by convex i
		w2 = self.sess.run(self.cw2, feed_dict={})

		dima = self.test_size
		dim = self.real_size
		multiplier = int(dim/dima)
		multiplier2 = multiplier*multiplier

		#write material
		#all output shapes share the same material
		#which means the same convex always has the same color for different shapes
		#change the colors in default.mtl to visualize correspondences between shapes
		with open(config.sample_dir+"/default.mtl", 'w') as fout2:
			for i in range(self.c_dim):
				fout2.write("newmtl m"+str(i+1)+"\n") #material id
				fout2.write("Kd 0.80 0.80 0.80\n") #color (diffuse) RGB 0.00-1.00
				fout2.write("Ka 0 0 0\n") #color (ambient) leave 0s

		for t in range(config.start, min(len(self.data_pixels),config.end)):
			model_float = np.ones([self.real_size,self.real_size,self.real_size,self.c_dim],np.float32)
			batch_view = self.data_pixels[t:t+1,self.test_idx].astype(np.float32)/255.0
			out_z = self.sess.run(self.sE,
				feed_dict={
					self.view: batch_view,
				})
			out_m, out_b = self.sess.run([self.zE_m, self.zE_b],
				feed_dict={
					self.z_vector: out_z,
				})
			#evaluate the per-convex field one coordinate chunk at a time
			for i in range(multiplier):
				for j in range(multiplier):
					for k in range(multiplier):
						minib = i*multiplier2+j*multiplier+k
						model_out = self.sess.run(self.zG2,
							feed_dict={
								self.plane_m: out_m,
								self.plane_b: out_b,
								self.point_coord: self.coords[minib:minib+1],
							})
						model_float[self.aux_x+i,self.aux_y+j,self.aux_z+k,:] = np.reshape(model_out, [self.test_size,self.test_size,self.test_size,self.c_dim])

			bsp_convex_list = []
			color_idx_list = []
			model_float = model_float<0.01
			model_float_sum = np.sum(model_float,axis=3)
			for i in range(self.c_dim):
				slice_i = model_float[:,:,:,i]
				if np.max(slice_i)>0: #if one voxel is inside a convex
					if np.min(model_float_sum-slice_i*2)>=0: #if this convex is redundant, i.e. the convex is inside the shape
						model_float_sum = model_float_sum-slice_i
					else:
						box = []
						for j in range(self.p_dim):
							if w2[j,i]>0.01:
								a = -out_m[0,0,j]
								b = -out_m[0,1,j]
								c = -out_m[0,2,j]
								d = -out_b[0,0,j]
								box.append([a,b,c,d])
						if len(box)>0:
							bsp_convex_list.append(np.array(box,np.float32))
							color_idx_list.append(i)

			print(len(bsp_convex_list))

			#write obj: one mesh per convex, tagged with that convex's material.
			#OBJ face indices are 1-based and global to the file, so each convex's
			#faces are offset by the number of vertices already written.
			vertex_count = 0
			with open(config.sample_dir+"/"+str(t)+"_bsp.obj", 'w') as fout2:
				fout2.write("mtllib default.mtl\n")
				for convex, color_idx in zip(bsp_convex_list, color_idx_list):
					vg, tg = get_mesh([convex])
					vbias = vertex_count+1
					vertex_count += len(vg)

					fout2.write("usemtl m"+str(color_idx+1)+"\n")
					for v in vg:
						fout2.write("v "+str(v[0])+" "+str(v[1])+" "+str(v[2])+"\n")
					for face in tg:
						fout2.write("f"+"".join(" "+str(idx+vbias) for idx in face)+"\n")
Beispiel #2
0
    def test_bsp(self, config):
        """Export each test voxel shape as a BSP polygon mesh in PLY format.

        After loading the checkpoint with ``self.load`` (printing a message and
        returning early on failure), every index ``t`` in
        ``[config.start, min(len(self.data_voxels), config.end))`` is encoded
        into plane parameters, the per-convex field ``self.zG2`` is sampled on
        the ``self.coords`` grid, empty and redundant convexes are discarded,
        and the remaining convexes are meshed with ``get_mesh`` and written to
        ``<config.sample_dir>/<t>_bsp.ply``.

        Args:
            config: object providing ``sample_dir``, ``start`` and ``end``.
        """
        loaded, _ = self.load(self.checkpoint_dir)
        if not loaded:
            print(" [!] Load failed...")
            return
        print(" [*] Load SUCCESS")

        # convex_weights[j, i] > 0.01 means plane j is used by convex i
        convex_weights = self.sess.run(self.cw2, feed_dict={})

        side = self.test_size
        steps = int(self.real_size / side)
        chunk_shape = [side, side, side, self.c_dim]
        full_shape = [self.real_size, self.real_size, self.real_size, self.c_dim]
        stop = min(len(self.data_voxels), config.end)

        for t in range(config.start, stop):
            field = np.ones(full_shape, np.float32)
            enc_m, enc_b = self.sess.run(
                [self.sE_m, self.sE_b],
                feed_dict={self.vox3d: self.data_voxels[t:t + 1]},
            )

            # sample the per-convex field one coordinate chunk at a time
            for i in range(steps):
                for j in range(steps):
                    for k in range(steps):
                        chunk_idx = (i * steps + j) * steps + k
                        chunk = self.sess.run(
                            self.zG2,
                            feed_dict={
                                self.plane_m: enc_m,
                                self.plane_b: enc_b,
                                self.point_coord: self.coords[chunk_idx:chunk_idx + 1],
                            },
                        )
                        field[self.aux_x + i, self.aux_y + j, self.aux_z + k, :] = \
                            np.reshape(chunk, chunk_shape)

            inside = field < 0.01
            coverage = np.sum(inside, axis=3)
            convexes = []
            for i in range(self.c_dim):
                occupancy = inside[:, :, :, i]
                if not (np.max(occupancy) > 0):
                    continue  # no voxel lies inside convex i
                if np.min(coverage - occupancy * 2) >= 0:
                    # redundant: every voxel of convex i is also inside another
                    # counted convex, so remove it from the running coverage
                    coverage = coverage - occupancy
                    continue
                planes = [
                    [-enc_m[0, 0, j], -enc_m[0, 1, j], -enc_m[0, 2, j], -enc_b[0, 0, j]]
                    for j in range(self.p_dim)
                    if convex_weights[j, i] > 0.01
                ]
                if planes:
                    convexes.append(np.array(planes, np.float32))

            print(len(convexes))

            # get_mesh_watertight(convexes) is an alternative that merges nearby
            # vertices to get watertight meshes
            vertices, polygons = get_mesh(convexes)
            write_ply_polygon(config.sample_dir + "/" + str(t) + "_bsp.ply",
                              vertices, polygons)
Beispiel #3
0
	def test_mesh_point(self, config):
		"""Export a BSP mesh and a filtered surface point cloud for each test view.

		Loads the checkpoint via ``self.load`` (prints a message and returns
		early on failure). For every ``t`` in
		``[config.start, min(len(self.data_pixels), config.end))``:

		* encodes view ``self.data_pixels[t, self.test_idx]`` (divided by 255)
		  into plane parameters ``out_m``/``out_b``;
		* samples the per-convex field ``self.zG2`` and the combined field
		  ``self.zG`` chunk by chunk over ``self.coords``;
		* keeps every convex containing at least one voxel (no redundancy
		  filtering), meshes them with ``get_mesh`` and writes
		  ``<sample_dir>/<t>_bsp.ply``;
		* calls ``sample_points_polygon_vox64(..., 16000)``, keeps the samples
		  whose position offset by 1e-4 times columns 3: gives a ``self.zG``
		  value above 1e-4, shuffles them and writes the first 4096 to
		  ``<sample_dir>/<t>_pc.ply``.

		Args:
			config: object providing ``sample_dir``, ``start`` and ``end``.
		"""
		could_load, _ = self.load(self.checkpoint_dir)
		if could_load:
			print(" [*] Load SUCCESS")
		else:
			print(" [!] Load failed...")
			return

		#w2[j,i]>0.01 means plane j is used by convex i
		w2 = self.sess.run(self.cw2, feed_dict={})
		dima = self.test_size
		dim = self.real_size
		multiplier = int(dim/dima)
		multiplier2 = multiplier*multiplier
		for t in range(config.start, min(len(self.data_pixels),config.end)):
			print(t)
			model_float = np.ones([self.real_size,self.real_size,self.real_size,self.c_dim],np.float32)
			model_float_combined = np.ones([self.real_size,self.real_size,self.real_size],np.float32)
			batch_view = self.data_pixels[t:t+1,self.test_idx].astype(np.float32)/255.0
			out_z = self.sess.run(self.sE,
				feed_dict={
					self.view: batch_view,
				})
			out_m, out_b = self.sess.run([self.zE_m, self.zE_b],
				feed_dict={
					self.z_vector: out_z,
				})
			for i in range(multiplier):
				for j in range(multiplier):
					for k in range(multiplier):
						minib = i*multiplier2+j*multiplier+k
						model_out, model_out_combined = self.sess.run([self.zG2, self.zG],
							feed_dict={
								self.plane_m: out_m,
								self.plane_b: out_b,
								self.point_coord: self.coords[minib:minib+1],
							})
						model_float[self.aux_x+i,self.aux_y+j,self.aux_z+k,:] = np.reshape(model_out, [self.test_size,self.test_size,self.test_size,self.c_dim])
						model_float_combined[self.aux_x+i,self.aux_y+j,self.aux_z+k] = np.reshape(model_out_combined, [self.test_size,self.test_size,self.test_size])

			#every convex containing at least one voxel is kept; this export
			#deliberately applies no redundancy filtering
			bsp_convex_list = []
			model_float = model_float<0.01
			for i in range(self.c_dim):
				if np.max(model_float[:,:,:,i])>0: #if one voxel is inside a convex
					box = []
					for j in range(self.p_dim):
						if w2[j,i]>0.01:
							a = -out_m[0,0,j]
							b = -out_m[0,1,j]
							c = -out_m[0,2,j]
							d = -out_b[0,0,j]
							box.append([a,b,c,d])
					if len(box)>0:
						bsp_convex_list.append(np.array(box,np.float32))

			#convert bspt to mesh
			vertices, polygons = get_mesh(bsp_convex_list)
			#use the following alternative to merge nearby vertices to get watertight meshes
			#vertices, polygons = get_mesh_watertight(bsp_convex_list)

			#output ply
			write_ply_polygon(config.sample_dir+"/"+str(t)+"_bsp.ply", vertices, polygons)

			#sample surface points
			sampled_points_normals = sample_points_polygon_vox64(vertices, polygons, model_float_combined, 16000)
			#check point inside shape or not: query self.zG at each sample moved
			#1e-4 along columns 3: (presumably the normals) and keep values > 1e-4
			sample_points_value = self.sess.run(self.zG,
				feed_dict={
					self.plane_m: out_m,
					self.plane_b: out_b,
					self.point_coord: np.reshape(sampled_points_normals[:,:3]+sampled_points_normals[:,3:]*1e-4, [1,-1,3]),
				})
			sampled_points_normals = sampled_points_normals[sample_points_value[0,:,0]>1e-4]
			print(len(bsp_convex_list), len(sampled_points_normals))
			np.random.shuffle(sampled_points_normals)
			write_ply_point_normal(config.sample_dir+"/"+str(t)+"_pc.ply", sampled_points_normals[:4096])
Beispiel #4
0
    def test_mesh_obj_material(self, config):
        """Export each test voxel shape as an OBJ mesh with one material per convex.

        Restores ``self.checkpoint_manager.latest_checkpoint`` (prints a message
        and returns early when there is none). Writes
        ``<sample_dir>/default.mtl`` with one grey material ``m<i+1>`` per convex
        slot ``i`` in ``range(self.c_dim)``, then for every ``t`` in
        ``[config.start, min(len(self.data_voxels), config.end))``:

        * runs ``self.bsp_network`` on ``self.data_voxels[t:t + 1]`` to obtain
          the plane parameters ``out_m``;
        * evaluates the per-convex output chunk by chunk over ``self.coords``
          into a ``real_size**3 x c_dim`` grid;
        * keeps every convex that contains at least one voxel and is not
          redundant (every voxel inside it is also inside at least one other
          convex still counted in the running coverage);
        * writes ``<sample_dir>/<t>_bsp.obj`` where convex slot ``i`` uses
          material ``m<i+1>``, so it gets the same colour in every shape.

        Args:
            config: object providing ``sample_dir``, ``start`` and ``end``.
        """
        if self.checkpoint_manager.latest_checkpoint:
            self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            return

        #w2[j, i] > 0.01 means plane j is used by convex i
        w2 = self.bsp_network.generator.convex_layer_weights.numpy()

        dima = self.test_size
        dim = self.real_size
        multiplier = int(dim / dima)
        multiplier2 = multiplier * multiplier

        #write material
        #all output shapes share the same material
        #which means the same convex always has the same color for different shapes
        #change the colors in default.mtl to visualize correspondences between shapes
        with open(config.sample_dir + "/default.mtl", 'w') as fout2:
            for i in range(self.c_dim):
                fout2.write("newmtl m" + str(i + 1) + "\n")  #material id
                fout2.write("Kd 0.80 0.80 0.80\n")  #color (diffuse) RGB 0.00-1.00
                fout2.write("Ka 0 0 0\n")  #color (ambient) leave 0s

        for t in range(config.start, min(len(self.data_voxels), config.end)):
            model_float = np.ones(
                [self.real_size, self.real_size, self.real_size, self.c_dim],
                np.float32)
            batch_voxels = self.data_voxels[t:t + 1]
            _, out_m, _, _ = self.bsp_network(batch_voxels,
                                              None,
                                              None,
                                              None,
                                              is_training=False)
            #evaluate the per-convex field one coordinate chunk at a time
            for i in range(multiplier):
                for j in range(multiplier):
                    for k in range(multiplier):
                        minib = i * multiplier2 + j * multiplier + k
                        point_coord = self.coords[minib:minib + 1]
                        _, _, model_out, _ = self.bsp_network(
                            None, None, out_m, point_coord, is_training=False)
                        model_float[self.aux_x + i, self.aux_y + j,
                                    self.aux_z + k, :] = np.reshape(
                                        model_out, [
                                            self.test_size, self.test_size,
                                            self.test_size, self.c_dim
                                        ])

            out_m = out_m.numpy()

            bsp_convex_list = []
            color_idx_list = []
            model_float = model_float < 0.01
            model_float_sum = np.sum(model_float, axis=3)
            for i in range(self.c_dim):
                slice_i = model_float[:, :, :, i]
                if np.max(slice_i) > 0:  #if one voxel is inside a convex
                    if np.min(
                            model_float_sum - slice_i * 2
                    ) >= 0:  #if this convex is redundant, i.e. the convex is inside the shape
                        model_float_sum = model_float_sum - slice_i
                    else:
                        box = []
                        for j in range(self.p_dim):
                            if w2[j, i] > 0.01:
                                a = -out_m[0, 0, j]
                                b = -out_m[0, 1, j]
                                c = -out_m[0, 2, j]
                                d = -out_m[0, 3, j]
                                box.append([a, b, c, d])
                        if len(box) > 0:
                            bsp_convex_list.append(np.array(box, np.float32))
                            color_idx_list.append(i)

            print(len(bsp_convex_list))

            #write obj: one mesh per convex, tagged with that convex's material.
            #OBJ face indices are 1-based and global to the file, so each
            #convex's faces are offset by the number of vertices already written.
            vertex_count = 0
            with open(config.sample_dir + "/" + str(t) + "_bsp.obj",
                      'w') as fout2:
                fout2.write("mtllib default.mtl\n")
                for convex, color_idx in zip(bsp_convex_list, color_idx_list):
                    vg, tg = get_mesh([convex])
                    vbias = vertex_count + 1
                    vertex_count += len(vg)

                    fout2.write("usemtl m" + str(color_idx + 1) + "\n")
                    for v in vg:
                        fout2.write("v " + str(v[0]) + " " + str(v[1]) + " " +
                                    str(v[2]) + "\n")
                    for face in tg:
                        fout2.write("f" + "".join(" " + str(idx + vbias)
                                                  for idx in face) + "\n")
Beispiel #5
0
    def test_mesh_point(self, config):
        """Export a BSP mesh and a filtered surface point cloud per test shape.

        Restores ``self.checkpoint_manager.latest_checkpoint`` (prints a message
        and returns early when there is none). For every ``t`` in
        ``[config.start, min(len(self.data_voxels), config.end))``:

        * runs ``self.bsp_network`` on ``self.data_voxels[t:t + 1]`` to obtain
          the plane parameters ``out_m``;
        * samples the per-convex and combined outputs chunk by chunk over
          ``self.coords``;
        * keeps every convex containing at least one voxel (no redundancy
          filtering), meshes them with ``get_mesh`` and writes
          ``<sample_dir>/<t>_bsp.ply``;
        * calls ``sample_points_polygon_vox64(..., 16000)``, keeps the samples
          whose position offset by 1e-4 times columns 3: gives a combined
          output above 1e-4, shuffles them and writes the first 4096 to
          ``<sample_dir>/<t>_pc.ply``.

        Args:
            config: object providing ``sample_dir``, ``start`` and ``end``.
        """
        if self.checkpoint_manager.latest_checkpoint:
            self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            return

        #w2[j, i] > 0.01 means plane j is used by convex i
        w2 = self.bsp_network.generator.convex_layer_weights.numpy()
        dima = self.test_size
        dim = self.real_size
        multiplier = int(dim / dima)
        multiplier2 = multiplier * multiplier
        for t in range(config.start, min(len(self.data_voxels), config.end)):
            print(t)
            model_float = np.ones(
                [self.real_size, self.real_size, self.real_size, self.c_dim],
                np.float32)
            model_float_combined = np.ones(
                [self.real_size, self.real_size, self.real_size], np.float32)
            batch_voxels = self.data_voxels[t:t + 1]
            _, out_m, _, _ = self.bsp_network(batch_voxels,
                                              None,
                                              None,
                                              None,
                                              is_training=False)
            for i in range(multiplier):
                for j in range(multiplier):
                    for k in range(multiplier):
                        minib = i * multiplier2 + j * multiplier + k
                        point_coord = self.coords[minib:minib + 1]
                        _, _, model_out, model_out_combined = self.bsp_network(
                            None, None, out_m, point_coord, is_training=False)
                        model_float[self.aux_x + i, self.aux_y + j,
                                    self.aux_z + k, :] = np.reshape(
                                        model_out, [
                                            self.test_size, self.test_size,
                                            self.test_size, self.c_dim
                                        ])
                        model_float_combined[self.aux_x + i, self.aux_y + j,
                                             self.aux_z + k] = np.reshape(
                                                 model_out_combined, [
                                                     self.test_size,
                                                     self.test_size,
                                                     self.test_size
                                                 ])

            out_m = out_m.numpy()

            #every convex containing at least one voxel is kept; this export
            #deliberately applies no redundancy filtering
            bsp_convex_list = []
            model_float = model_float < 0.01
            for i in range(self.c_dim):
                if np.max(model_float[:, :, :, i]) > 0:  #if one voxel is inside a convex
                    box = []
                    for j in range(self.p_dim):
                        if w2[j, i] > 0.01:
                            a = -out_m[0, 0, j]
                            b = -out_m[0, 1, j]
                            c = -out_m[0, 2, j]
                            d = -out_m[0, 3, j]
                            box.append([a, b, c, d])
                    if len(box) > 0:
                        bsp_convex_list.append(np.array(box, np.float32))

            #convert bspt to mesh
            vertices, polygons = get_mesh(bsp_convex_list)
            #use the following alternative to merge nearby vertices to get watertight meshes
            #vertices, polygons = get_mesh_watertight(bsp_convex_list)

            #output ply
            write_ply_polygon(config.sample_dir + "/" + str(t) + "_bsp.ply",
                              vertices, polygons)

            #sample surface points
            sampled_points_normals = sample_points_polygon_vox64(
                vertices, polygons, model_float_combined, 16000)
            #check point inside shape or not: move each sample 1e-4 along
            #columns 3: (presumably the normals) and append a constant 1 as the
            #fourth coordinate before querying the network
            point_coord = np.reshape(
                sampled_points_normals[:, :3] +
                sampled_points_normals[:, 3:] * 1e-4, [1, -1, 3])
            point_coord = np.concatenate([
                point_coord,
                np.ones([1, point_coord.shape[1], 1], np.float32)
            ],
                                         axis=2)
            _, _, _, sample_points_value = self.bsp_network(None,
                                                            None,
                                                            out_m,
                                                            point_coord,
                                                            is_training=False)
            #convert the network output to a NumPy array before using it as a
            #boolean mask on a NumPy array
            sample_points_value = np.asarray(sample_points_value)
            sampled_points_normals = sampled_points_normals[
                sample_points_value[0, :, 0] > 1e-4]
            print(len(bsp_convex_list), len(sampled_points_normals))
            np.random.shuffle(sampled_points_normals)
            write_ply_point_normal(
                config.sample_dir + "/" + str(t) + "_pc.ply",
                sampled_points_normals[:4096])
Beispiel #6
0
    def test_bsp(self, config):
        """Export each test voxel shape as a BSP polygon mesh in PLY format.

        Restores the latest checkpoint from ``self.checkpoint_manager``
        (printing a message and returning early when there is none). Then every
        index ``t`` in ``[config.start, min(len(self.data_voxels), config.end))``
        is encoded into plane parameters, the per-convex output of
        ``self.bsp_network`` is sampled on the ``self.coords`` grid, empty and
        redundant convexes are discarded, and the rest are meshed with
        ``get_mesh`` and written to ``<config.sample_dir>/<t>_bsp.ply``.

        Args:
            config: object providing ``sample_dir``, ``start`` and ``end``.
        """
        latest = self.checkpoint_manager.latest_checkpoint
        if not latest:
            print(" [!] Load failed...")
            return
        self.checkpoint.restore(latest)
        print(" [*] Load SUCCESS")

        # convex_weights[j, i] > 0.01 means plane j is used by convex i
        convex_weights = self.bsp_network.generator.convex_layer_weights.numpy()

        side = self.test_size
        steps = int(self.real_size / side)
        chunk_shape = [side, side, side, self.c_dim]
        full_shape = [self.real_size, self.real_size, self.real_size, self.c_dim]
        stop = min(len(self.data_voxels), config.end)

        for t in range(config.start, stop):
            field = np.ones(full_shape, np.float32)
            _, planes_t, _, _ = self.bsp_network(
                self.data_voxels[t:t + 1], None, None, None, is_training=False)

            # sample the per-convex output one coordinate chunk at a time
            for i in range(steps):
                for j in range(steps):
                    for k in range(steps):
                        chunk_idx = (i * steps + j) * steps + k
                        _, _, chunk, _ = self.bsp_network(
                            None, None, planes_t,
                            self.coords[chunk_idx:chunk_idx + 1],
                            is_training=False)
                        field[self.aux_x + i, self.aux_y + j, self.aux_z + k, :] = \
                            np.reshape(chunk, chunk_shape)

            plane_params = planes_t.numpy()

            inside = field < 0.01
            coverage = np.sum(inside, axis=3)
            convexes = []
            for i in range(self.c_dim):
                occupancy = inside[:, :, :, i]
                if not (np.max(occupancy) > 0):
                    continue  # no voxel lies inside convex i
                if np.min(coverage - occupancy * 2) >= 0:
                    # redundant: every voxel of convex i is also inside another
                    # counted convex, so remove it from the running coverage
                    coverage = coverage - occupancy
                    continue
                planes = [
                    [-plane_params[0, 0, j], -plane_params[0, 1, j],
                     -plane_params[0, 2, j], -plane_params[0, 3, j]]
                    for j in range(self.p_dim)
                    if convex_weights[j, i] > 0.01
                ]
                if planes:
                    convexes.append(np.array(planes, np.float32))

            print(len(convexes))

            # get_mesh_watertight(convexes) is an alternative that merges nearby
            # vertices to get watertight meshes
            vertices, polygons = get_mesh(convexes)
            write_ply_polygon(config.sample_dir + "/" + str(t) + "_bsp.ply",
                              vertices, polygons)
    def test_real_data(self, config, pic_path):
        """Reconstruct one image file as a BSP mesh plus a surface point cloud.

        Reads the model path from the first line of
        ``<self.checkpoint_path>/checkpoint`` and loads it into
        ``self.bsp_network`` (prints a message and returns early when the
        index file is missing). The image from ``self.load_image(pic_path)``
        is encoded into plane parameters. Per-convex and combined outputs are
        sampled over ``self.coords``, and every convex containing at least one
        voxel is kept (no redundancy filtering). The outputs written are
        ``<sample_dir>/100000_bsp.ply`` (mesh) and ``<sample_dir>/100000_pc.ply``
        (up to 4096 filtered surface samples).

        Args:
            config: object providing ``sample_dir``.
            pic_path: path passed unchanged to ``self.load_image``.
        """
        checkpoint_txt = os.path.join(self.checkpoint_path, "checkpoint")
        if os.path.exists(checkpoint_txt):
            #first line of the checkpoint index names the model file to load
            with open(checkpoint_txt) as fin:
                model_dir = fin.readline().strip()
            self.bsp_network.load_state_dict(torch.load(model_dir))
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            return

        #w2[j, i] > 0.01 means plane j is used by convex i
        w2 = self.bsp_network.generator.convex_layer_weights.detach().cpu(
        ).numpy()
        dima = self.test_size
        dim = self.real_size
        multiplier = int(dim / dima)
        multiplier2 = multiplier * multiplier

        self.bsp_network.eval()

        model_float = np.ones(
            [self.real_size, self.real_size, self.real_size, self.c_dim],
            np.float32)
        model_float_combined = np.ones(
            [self.real_size, self.real_size, self.real_size], np.float32)

        #fixed index, used only to name the output files
        t = 100000
        #NOTE(review): load_image presumably returns an array laid out like the
        #training views the encoder expects; confirm (original: "FIXME: Add figure here")
        batch_view = self.load_image(pic_path)

        batch_view = torch.from_numpy(batch_view)
        batch_view = batch_view.to(self.device)
        #inference only: disable autograd so no graph or activations are kept
        with torch.no_grad():
            _, out_m, _, _ = self.bsp_network(batch_view,
                                              None,
                                              None,
                                              None,
                                              is_training=False)
            for i in range(multiplier):
                for j in range(multiplier):
                    for k in range(multiplier):
                        minib = i * multiplier2 + j * multiplier + k
                        point_coord = self.coords[minib:minib + 1]
                        _, _, model_out, model_out_combined = self.bsp_network(
                            None, None, out_m, point_coord, is_training=False)
                        model_float[self.aux_x + i, self.aux_y + j,
                                    self.aux_z + k, :] = np.reshape(
                                        model_out.detach().cpu().numpy(), [
                                            self.test_size, self.test_size,
                                            self.test_size, self.c_dim
                                        ])
                        model_float_combined[
                            self.aux_x + i, self.aux_y + j,
                            self.aux_z + k] = np.reshape(
                                model_out_combined.detach().cpu().numpy(),
                                [self.test_size, self.test_size, self.test_size])

        out_m_ = out_m.detach().cpu().numpy()
        #every convex containing at least one voxel is kept; this export
        #deliberately applies no redundancy filtering
        bsp_convex_list = []
        model_float = model_float < 0.01
        for i in range(self.c_dim):
            if np.max(model_float[:, :, :, i]) > 0:  #if one voxel is inside a convex
                box = []
                for j in range(self.p_dim):
                    if w2[j, i] > 0.01:
                        a = -out_m_[0, 0, j]
                        b = -out_m_[0, 1, j]
                        c = -out_m_[0, 2, j]
                        d = -out_m_[0, 3, j]
                        box.append([a, b, c, d])
                if len(box) > 0:
                    bsp_convex_list.append(np.array(box, np.float32))

        #convert bspt to mesh
        vertices, polygons = get_mesh(bsp_convex_list)
        #use the following alternative to merge nearby vertices to get watertight meshes
        #vertices, polygons = get_mesh_watertight(bsp_convex_list)

        #output ply
        write_ply_polygon(config.sample_dir + "/" + str(t) + "_bsp.ply",
                          vertices, polygons)

        #sample surface points
        sampled_points_normals = sample_points_polygon_vox64(
            vertices, polygons, model_float_combined, 16000)
        #check point inside shape or not: move each sample 1e-4 along columns
        #3: (presumably the normals) and append a constant 1 as the fourth
        #coordinate before querying the network
        point_coord = np.reshape(
            sampled_points_normals[:, :3] +
            sampled_points_normals[:, 3:] * 1e-4, [1, -1, 3])
        point_coord = np.concatenate(
            [point_coord,
             np.ones([1, point_coord.shape[1], 1], np.float32)],
            axis=2)
        point_coord = torch.from_numpy(point_coord)
        point_coord = point_coord.to(self.device)
        with torch.no_grad():
            _, _, _, sample_points_value = self.bsp_network(None,
                                                            None,
                                                            out_m,
                                                            point_coord,
                                                            is_training=False)
        sample_points_value = sample_points_value.detach().cpu().numpy()
        sampled_points_normals = sampled_points_normals[
            sample_points_value[0, :, 0] > 1e-4]
        print(len(bsp_convex_list), len(sampled_points_normals))
        np.random.shuffle(sampled_points_normals)
        write_ply_point_normal(config.sample_dir + "/" + str(t) + "_pc.ply",
                               sampled_points_normals[:4096])
    def test_bsp(self, config):
        """Export each test view as a BSP polygon mesh in PLY format.

        Reads the model path from the first line of
        ``<self.checkpoint_path>/checkpoint`` and loads it into
        ``self.bsp_network`` (prints a message and returns early when the
        index file is missing). Then, for every ``t`` in
        ``[config.start, min(len(self.data_pixels), config.end))``:

        * encodes view ``self.data_pixels[t, self.test_idx]`` (divided by 255);
        * samples the per-convex output chunk by chunk over ``self.coords``;
        * keeps every convex that contains at least one voxel and is not
          redundant (every voxel inside it is also inside at least one other
          convex still counted in the running coverage);
        * meshes the kept convexes with ``get_mesh`` and writes
          ``<sample_dir>/<t>_bsp.ply``.

        Args:
            config: object providing ``sample_dir``, ``start`` and ``end``.
        """
        #load previous checkpoint
        checkpoint_txt = os.path.join(self.checkpoint_path, "checkpoint")
        if os.path.exists(checkpoint_txt):
            #first line of the checkpoint index names the model file to load
            with open(checkpoint_txt) as fin:
                model_dir = fin.readline().strip()
            self.bsp_network.load_state_dict(torch.load(model_dir))
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            return

        #w2[j, i] > 0.01 means plane j is used by convex i
        w2 = self.bsp_network.generator.convex_layer_weights.detach().cpu(
        ).numpy()

        dima = self.test_size
        dim = self.real_size
        multiplier = int(dim / dima)
        multiplier2 = multiplier * multiplier

        self.bsp_network.eval()
        for t in range(config.start, min(len(self.data_pixels), config.end)):
            model_float = np.ones(
                [self.real_size, self.real_size, self.real_size, self.c_dim],
                np.float32)
            batch_view = self.data_pixels[t:t + 1, self.test_idx].astype(
                np.float32) / 255.0
            batch_view = torch.from_numpy(batch_view)
            batch_view = batch_view.to(self.device)
            #inference only: disable autograd so no graph or activations are kept
            with torch.no_grad():
                _, out_m, _, _ = self.bsp_network(batch_view,
                                                  None,
                                                  None,
                                                  None,
                                                  is_training=False)
                for i in range(multiplier):
                    for j in range(multiplier):
                        for k in range(multiplier):
                            minib = i * multiplier2 + j * multiplier + k
                            point_coord = self.coords[minib:minib + 1]
                            _, _, model_out, _ = self.bsp_network(
                                None, None, out_m, point_coord,
                                is_training=False)
                            model_float[self.aux_x + i, self.aux_y + j,
                                        self.aux_z + k, :] = np.reshape(
                                            model_out.detach().cpu().numpy(), [
                                                self.test_size, self.test_size,
                                                self.test_size, self.c_dim
                                            ])

            out_m = out_m.detach().cpu().numpy()

            bsp_convex_list = []
            model_float = model_float < 0.01
            model_float_sum = np.sum(model_float, axis=3)
            for i in range(self.c_dim):
                slice_i = model_float[:, :, :, i]
                if np.max(slice_i) > 0:  #if one voxel is inside a convex
                    if np.min(
                            model_float_sum - slice_i * 2
                    ) >= 0:  #if this convex is redundant, i.e. the convex is inside the shape
                        model_float_sum = model_float_sum - slice_i
                    else:
                        box = []
                        for j in range(self.p_dim):
                            if w2[j, i] > 0.01:
                                a = -out_m[0, 0, j]
                                b = -out_m[0, 1, j]
                                c = -out_m[0, 2, j]
                                d = -out_m[0, 3, j]
                                box.append([a, b, c, d])
                        if len(box) > 0:
                            bsp_convex_list.append(np.array(box, np.float32))

            print(len(bsp_convex_list))

            #convert bspt to mesh
            vertices, polygons = get_mesh(bsp_convex_list)
            #use the following alternative to merge nearby vertices to get watertight meshes
            #vertices, polygons = get_mesh_watertight(bsp_convex_list)

            #output ply
            write_ply_polygon(config.sample_dir + "/" + str(t) + "_bsp.ply",
                              vertices, polygons)