def load_sdf_net(filename=None, return_latent_codes=False):
    """Load a pretrained SDFNet and put it in evaluation mode.

    Args:
        filename: Optional override for ``SDFNet.filename``, applied before
            ``load()``. When ``None``, whatever filename ``SDFNet`` sets by
            default is used.
        return_latent_codes: When ``True``, also load the latent-code tensor
            stored at ``LATENT_CODES_FILENAME``.

    Returns:
        The loaded network (in eval mode). If ``return_latent_codes`` is
        ``True``, a ``(sdf_net, latent_codes)`` tuple instead, where
        ``latent_codes`` lives on ``device`` and does not require gradients.
    """
    # Local import, kept from the original.
    from model.sdf_net import SDFNet, LATENT_CODES_FILENAME

    sdf_net = SDFNet()
    if filename is not None:
        sdf_net.filename = filename
    sdf_net.load()
    sdf_net.eval()

    if not return_latent_codes:
        return sdf_net

    # map_location puts the tensor straight onto `device`, so codes saved from
    # a GPU run can also be loaded on a CPU-only machine. detach() returns a
    # leaf tensor with requires_grad=False. Setting requires_grad=False on the
    # non-leaf result of a cross-device .to() would raise a RuntimeError.
    latent_codes = torch.load(LATENT_CODES_FILENAME, map_location=device).detach()
    return sdf_net, latent_codes
voxels_current, level=0, spacing=(size / voxel_resolution, size / voxel_resolution, size / voxel_resolution)) vertices -= size / 2 mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) point_cloud = mesh.sample(point_cloud_size) rescale_point_cloud(point_cloud, method=rescale) result[i, :, :] = point_cloud return result if 'sample' in sys.argv: sdf_net = SDFNet() sdf_net.filename = 'hybrid_gan_generator.to' sdf_net.load() sdf_net.eval() clouds = sample_point_clouds(sdf_net, 1000, 2048, voxel_resolution=32) np.save('data/generated_point_cloud_sample.npy', clouds) if 'checkpoints' in sys.argv: import glob from tqdm import tqdm torch.manual_seed(1234) files = glob.glob( 'models/checkpoints/hybrid_progressive_gan_generator_2-epoch-*.to', recursive=True) latent_codes = standard_normal_distribution.sample(
POINTCLOUD_SIZE = 200000 points = torch.load('data/sdf_points.to').to(device) sdf = torch.load('data/sdf_values.to').to(device) MODEL_COUNT = points.shape[0] // POINTCLOUD_SIZE BATCH_SIZE = 20000 SDF_CUTOFF = 0.1 sdf.clamp_(-SDF_CUTOFF, SDF_CUTOFF) signs = sdf.cpu().numpy() > 0 SIGMA = 0.01 LOG_FILE_NAME = "plots/sdf_net_training.csv" sdf_net = SDFNet() if "continue" in sys.argv: sdf_net.load() latent_codes = torch.load(LATENT_CODES_FILENAME).to(device) else: normal_distribution = torch.distributions.normal.Normal(0, 0.0001) latent_codes = normal_distribution.sample( (MODEL_COUNT, LATENT_CODE_SIZE)).to(device) latent_codes.requires_grad = True network_optimizer = optim.Adam(sdf_net.parameters(), lr=1e-5) latent_code_optimizer = optim.Adam([latent_codes], lr=1e-5) criterion = nn.MSELoss() first_epoch = 0 if 'continue' in sys.argv:
dataset = VoxelDataset.glob('data/chairs/voxels_32/**.npy') dataloader = DataLoader(dataset, batch_size=1000, num_workers=8) latent_codes = torch.zeros((len(dataset), LATENT_CODE_SIZE)) with torch.no_grad(): position = 0 for batch in tqdm(dataloader): latent_codes[position:position + batch.shape[0], :] = vae.encode( batch.to(device)).detach().cpu() latent_codes = latent_codes.numpy() else: from model.sdf_net import SDFNet, LATENT_CODES_FILENAME latent_codes = torch.load(LATENT_CODES_FILENAME).detach().cpu().numpy() sdf_net = SDFNet() sdf_net.load() sdf_net.eval() from shapenet_metadata import shapenet raise NotImplementedError('A labels tensor needs to be supplied here.') labels = None print("Calculating embedding...") tsne = TSNE(n_components=2) latent_codes_embedded = tsne.fit_transform(latent_codes) print("Calculating clusters...") kmeans = KMeans(n_clusters=SAMPLE_COUNT) indices = np.zeros(SAMPLE_COUNT, dtype=int) kmeans_clusters = kmeans.fit_predict(latent_codes_embedded)
mesh = trimesh.load(MODEL_PATH) points, sdf = sample_sdf_near_surface(mesh) save_images = 'save' in sys.argv if save_images: viewer = MeshRenderer(start_thread=False, size=1080) ensure_directory('images') else: viewer = MeshRenderer() points = torch.tensor(points, dtype=torch.float32, device=device) sdf = torch.tensor(sdf, dtype=torch.float32, device=device) sdf.clamp_(-0.1, 0.1) sdf_net = SDFNet(latent_code_size=LATENT_CODE_SIZE).to(device) optimizer = torch.optim.Adam(sdf_net.parameters(), lr=1e-5) BATCH_SIZE = 20000 latent_code = torch.zeros((BATCH_SIZE, LATENT_CODE_SIZE), device=device) indices = torch.zeros(BATCH_SIZE, dtype=torch.int64, device=device) positive_indices = (sdf > 0).nonzero().squeeze().cpu().numpy() negative_indices = (sdf < 0).nonzero().squeeze().cpu().numpy() step = 0 error_targets = np.logspace(np.log10(0.02), np.log10(0.0005), num=500) image_index = 0 while True: try:
] checkpoints_network = sorted(checkpoints_network) checkpoints_latent_codes = sorted(checkpoints_latent_codes) indices = [ i * (len(checkpoints_network) - 1) // (COUNT - 1) for i in range(COUNT) ] checkpoints_network = [checkpoints_network[i] for i in indices] checkpoints_latent_codes = [checkpoints_latent_codes[i] for i in indices] print('\n'.join(checkpoints_network)) MODEL_INDEX = 1000 print(MODEL_INDEX) from model.sdf_net import SDFNet sdf_net = SDFNet() sdf_net.eval() plot = ImageGrid(COUNT, create_viewer=False) for i in range(COUNT): sdf_net.load_state_dict( torch.load(os.path.join(CHECKPOINT_PATH, checkpoints_network[i]))) latent_codes = torch.load( os.path.join(CHECKPOINT_PATH, checkpoints_latent_codes[i])).detach() latent_code = latent_codes[MODEL_INDEX, :] plot.set_image(render_image(sdf_net, latent_code, crop=True), i) plot.save('plots/deepsdf-checkpoints.pdf') if "deepsdf-interpolation-stl" in sys.argv:
from datasets import VoxelDataset
from torch.utils.data import DataLoader

LEARN_RATE = 0.00001
BATCH_SIZE = 8
CRITIC_UPDATES_PER_GENERATOR_UPDATE = 5
# Presumably the WGAN weight-clipping bound applied in the training loop
# (not visible in this chunk) — confirm.
CRITIC_WEIGHT_LIMIT = 0.01

dataset = VoxelDataset.glob('data/chairs/voxels_32/**.npy')
dataset.rescale_sdf = False
data_loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=8)

generator = SDFNet()
generator.filename = 'hybrid_wgan_generator.to'

critic = Discriminator()
critic.filename = 'hybrid_wgan_critic.to'
# Presumably disables a final sigmoid so the critic emits unbounded scores,
# as a WGAN critic expects — confirm against Discriminator.
critic.use_sigmoid = False

# Resume from previously saved weights when run with "continue".
if "continue" in sys.argv:
    generator.load()
    critic.load()

LOG_FILE_NAME = "plots/hybrid_wgan_training.csv"
first_epoch = 0
if 'continue' in sys.argv:
    # The epoch counter resumes at the number of lines in the log. This
    # assumes one line per completed epoch and no header — TODO confirm.
    # The `with` block closes the file, which the bare open() call leaked.
    with open(LOG_FILE_NAME, 'r') as log_file:
        log_file_contents = log_file.readlines()
    first_epoch = len(log_file_contents)
import sys
from collections import deque
from tqdm import tqdm
from model.sdf_net import SDFNet
from model.gan import Discriminator, LATENT_CODE_SIZE
from util import create_text_slice, device, standard_normal_distribution, get_voxel_coordinates

VOXEL_RESOLUTION = 32
SDF_CLIPPING = 0.1

# (A second, identical `from util import create_text_slice` was dropped here:
# the name is already imported above.)
from datasets import VoxelDataset
from torch.utils.data import DataLoader

generator = SDFNet()
generator.filename = 'hybrid_gan_generator.to'

discriminator = Discriminator()
discriminator.filename = 'hybrid_gan_discriminator.to'

# Resume from previously saved weights when run with "continue".
if "continue" in sys.argv:
    generator.load()
    discriminator.load()

LOG_FILE_NAME = "plots/hybrid_gan_training.csv"
first_epoch = 0
if 'continue' in sys.argv:
    # The epoch counter resumes at the number of lines in the log. This
    # assumes one line per completed epoch and no header — TODO confirm.
    # The `with` block closes the file, which the bare open() call leaked.
    with open(LOG_FILE_NAME, 'r') as log_file:
        log_file_contents = log_file.readlines()
    first_epoch = len(log_file_contents)
import time
import torch
from tqdm import tqdm
import cv2
import random
import sys

# How many distinct objects to generate and interpolate between.
SAMPLE_COUNT = 30
TRANSITION_FRAMES = 60
ROTATE_MODEL = False
USE_HYBRID_GAN = True

# The two model variants use different surface thresholds. Presumably this is
# the iso-level used for surface extraction later in the script — confirm.
if USE_HYBRID_GAN:
    SURFACE_LEVEL = 0.04
else:
    SURFACE_LEVEL = 0.011

sdf_net = SDFNet()
if USE_HYBRID_GAN:
    sdf_net.filename = 'hybrid_progressive_gan_generator_3.to'
sdf_net.load()
sdf_net.eval()

# Pick SAMPLE_COUNT + 1 keyframe codes. Hybrid GAN: draw fresh codes from the
# prior. Otherwise: pick random rows of the stored latent codes.
if not USE_HYBRID_GAN:
    latent_codes = torch.load(LATENT_CODES_FILENAME).detach().cpu().numpy()
    indices = random.sample(range(latent_codes.shape[0]), SAMPLE_COUNT + 1)
    codes = latent_codes[indices, :]
else:
    codes = standard_normal_distribution.sample(
        (SAMPLE_COUNT + 1, LATENT_CODE_SIZE)
    ).numpy()

# Copy the last keyframe into the first so the animation loops. A periodic
# CubicSpline requires the first and last samples to be equal.
codes[0, :] = codes[-1, :]
spline = scipy.interpolate.CubicSpline(
    np.arange(SAMPLE_COUNT + 1),
    codes,
    axis=0,
    bc_type='periodic',
)
# NOTE(review): the nesting below was reconstructed from whitespace-collapsed
# source. It assumes `discriminator.set_iteration(ITERATION)` and the
# generator filename assignment run unconditionally after the warm-start
# branch — confirm against the original file.

# Voxel resolution for the current progressive-growing iteration
# (RESOLUTIONS and ITERATION are defined outside this chunk).
VOXEL_RESOLUTION = RESOLUTIONS[ITERATION]
# The format call fills in the resolution. The escaped `{{:s}}` survives as a
# `{:s}` placeholder, presumably filled by from_split with the entries listed
# in train.txt — confirm.
dataset = VoxelDataset.from_split(
    'data/chairs/voxels_{:d}/{{:s}}.npy'.format(VOXEL_RESOLUTION),
    'data/chairs/train.txt')
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)


def get_generator_filename(iteration):
    # Checkpoint filename for the generator of the given iteration.
    return 'hybrid_progressive_gan_generator_{:d}.to'.format(iteration)


# The generator is created on the CPU. Presumably it is moved to the GPU(s)
# further down — confirm.
generator = SDFNet(device='cpu')
discriminator = Discriminator()
# Warm start: when beginning a new iteration (not resuming) after the first,
# initialise both networks from the previous iteration's checkpoints.
if not CONTINUE and ITERATION > 0:
    generator.filename = get_generator_filename(ITERATION - 1)
    generator.load()
    discriminator.set_iteration(ITERATION - 1)
    discriminator.load()
# From here on, both networks refer to the current iteration.
discriminator.set_iteration(ITERATION)
generator.filename = get_generator_filename(ITERATION)
# Resuming: reload the current iteration's checkpoints.
if CONTINUE:
    generator.load()
    discriminator.load()

# NOTE(review): this branch may continue past the visible chunk
# (e.g. with a DataParallel wrap).
if torch.cuda.device_count() > 1:
    print("Using dataparallel with {:d} GPUs.".format(
        torch.cuda.device_count()))