Ejemplo n.º 1
0
# NOTE(review): SIGMA is not referenced in the visible part of this script;
# confirm its role against the rest of the file.
SIGMA = 0.01

# Path of the log file that is opened at the end of this setup section.
LOG_FILE_NAME = "plots/sdf_net_training.csv"

# Passing "continue" on the command line resumes a previous run: the network
# and the latent codes are loaded from disk. Otherwise the latent codes are
# drawn from a normal distribution with mean 0 and standard deviation 0.0001.
sdf_net = SDFNet()
if "continue" in sys.argv:
    sdf_net.load()
    latent_codes = torch.load(LATENT_CODES_FILENAME).to(device)
else:
    normal_distribution = torch.distributions.normal.Normal(0, 0.0001)
    latent_codes = normal_distribution.sample(
        (MODEL_COUNT, LATENT_CODE_SIZE)).to(device)
# The latent codes are optimized alongside the network weights.
latent_codes.requires_grad = True

# Separate optimizers for the network parameters and for the latent codes.
network_optimizer = optim.Adam(sdf_net.parameters(), lr=1e-5)
latent_code_optimizer = optim.Adam([latent_codes], lr=1e-5)
criterion = nn.MSELoss()

# When resuming, start counting at the number of lines already in the log
# (presumably one line per finished epoch -- confirm against the writer).
first_epoch = 0
if 'continue' in sys.argv:
    # Use a context manager so the read handle is closed before the same
    # file is reopened for appending below.
    with open(LOG_FILE_NAME, 'r') as existing_log:
        log_file_contents = existing_log.readlines()
    first_epoch = len(log_file_contents)

# Append to the existing log when resuming, otherwise start a fresh file.
log_file = open(LOG_FILE_NAME, "a" if "continue" in sys.argv else "w")


def create_batches():
    indices_positive = np.nonzero(signs)[0]
    indices_negative = np.nonzero(~signs)[0]
    if indices_negative.shape[0] > indices_positive.shape[0]:
Ejemplo n.º 2
0
# Configure the critic created earlier in the script (outside this section).
critic.filename = 'hybrid_wgan_critic.to'
critic.use_sigmoid = False

# Passing "continue" on the command line resumes from saved model state.
if "continue" in sys.argv:
    generator.load()
    critic.load()

LOG_FILE_NAME = "plots/hybrid_wgan_training.csv"
# When resuming, start counting at the number of lines already in the log
# (presumably one line per finished epoch -- confirm against the writer).
first_epoch = 0
if 'continue' in sys.argv:
    # Use a context manager so the read handle is closed before the same
    # file is reopened for appending below.
    with open(LOG_FILE_NAME, 'r') as existing_log:
        log_file_contents = existing_log.readlines()
    first_epoch = len(log_file_contents)

# Append to the existing log when resuming, otherwise start a fresh file.
log_file = open(LOG_FILE_NAME, "a" if "continue" in sys.argv else "w")

generator_optimizer = optim.Adam(generator.parameters(), lr=LEARN_RATE)

# NOTE(review): binary_cross_entropy expects inputs in [0, 1], yet the critic
# is configured with use_sigmoid = False above -- confirm whether this
# criterion is actually used further down.
critic_criterion = torch.nn.functional.binary_cross_entropy
critic_optimizer = optim.RMSprop(critic.parameters(), lr=LEARN_RATE)

# The viewer is enabled unless "nogui" is passed on the command line.
show_viewer = "nogui" not in sys.argv

if show_viewer:
    # Imported only when needed so that "nogui" runs never load the renderer.
    from rendering import MeshRenderer
    viewer = MeshRenderer()

# Constant all-ones / all-zeros target tensors, one entry per batch element.
valid_target = torch.ones(BATCH_SIZE, requires_grad=False).to(device)
fake_target = torch.zeros(BATCH_SIZE, requires_grad=False).to(device)


def sample_latent_codes():
Ejemplo n.º 3
0
# Sample points and their SDF values for the mesh loaded earlier.
points, sdf = sample_sdf_near_surface(mesh)

# "save" on the command line selects the renderer configured without its
# own thread at size 1080, and makes sure the 'images' directory exists.
save_images = 'save' in sys.argv

if not save_images:
    viewer = MeshRenderer()
else:
    viewer = MeshRenderer(start_thread=False, size=1080)
    ensure_directory('images')

# Move the samples onto the training device as float32 tensors and clamp
# the SDF values in place to [-0.1, 0.1].
points = torch.tensor(points, dtype=torch.float32, device=device)
sdf = torch.tensor(sdf, dtype=torch.float32, device=device).clamp_(-0.1, 0.1)

sdf_net = SDFNet(latent_code_size=LATENT_CODE_SIZE).to(device)
optimizer = torch.optim.Adam(sdf_net.parameters(), lr=1e-5)

# Pre-allocated per-batch buffers: an all-zero latent code for every sample
# and an int64 index buffer that the training loop fills in.
BATCH_SIZE = 20000
indices = torch.zeros(BATCH_SIZE, dtype=torch.int64, device=device)
latent_code = torch.zeros((BATCH_SIZE, LATENT_CODE_SIZE), device=device)

# Indices of the samples with strictly positive / strictly negative SDF,
# kept as NumPy arrays.
positive_indices = torch.nonzero(sdf > 0).squeeze().cpu().numpy()
negative_indices = torch.nonzero(sdf < 0).squeeze().cpu().numpy()

# 500 log-spaced values running from 0.02 down to 0.0005.
error_targets = np.logspace(np.log10(0.02), np.log10(0.0005), num=500)
step = 0
image_index = 0

while True:
    try:
        indices[:BATCH_SIZE // 2] = torch.tensor(np.random.choice(
Ejemplo n.º 4
0
discriminator = Discriminator()
discriminator.filename = 'hybrid_gan_discriminator.to'

# Passing "continue" on the command line resumes from saved model state.
if "continue" in sys.argv:
    generator.load()
    discriminator.load()

LOG_FILE_NAME = "plots/hybrid_gan_training.csv"
# When resuming, start counting at the number of lines already in the log
# (presumably one line per finished epoch -- confirm against the writer).
first_epoch = 0
if 'continue' in sys.argv:
    # Use a context manager so the read handle is closed before the same
    # file is reopened for appending below.
    with open(LOG_FILE_NAME, 'r') as existing_log:
        log_file_contents = existing_log.readlines()
    first_epoch = len(log_file_contents)

# Append to the existing log when resuming, otherwise start a fresh file.
log_file = open(LOG_FILE_NAME, "a" if "continue" in sys.argv else "w")

# The generator uses a learning rate 100x larger than the discriminator's.
generator_optimizer = optim.Adam(generator.parameters(), lr=0.001)

discriminator_criterion = torch.nn.functional.binary_cross_entropy
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.00001)

# The viewer is enabled unless "nogui" is passed on the command line.
show_viewer = "nogui" not in sys.argv

if show_viewer:
    # Imported only when needed so that "nogui" runs never load the renderer.
    from rendering import MeshRenderer
    viewer = MeshRenderer()

BATCH_SIZE = 8

# Collect every .npy file matched by the glob pattern below.
dataset = VoxelDataset.glob('data/chairs/voxels_32/**.npy')
# NOTE(review): presumably disables SDF rescaling inside the dataset --
# confirm against the VoxelDataset implementation.
dataset.rescale_sdf = False
data_loader = DataLoader(dataset,