pos_gt = torch.stack((x_gt, y_gt, z_gt), dim=-1) imgs_gt = [] print("Rendering GT images...") for i in trange(numsteps): _vertices = vertices_gt.clone() + pos_gt[i] rgba = renderer.forward(_vertices, faces_gt, textures) imgs_gt.append(rgba) logdir = Path(args.logdir) / args.expid if args.log: logdir.mkdir(exist_ok=True) if args.log: write_imglist_to_gif(imgs_gt, logdir / "gt.gif", imgformat="rgba", verbose=False) if args.save_timelapse: timelapse = kaolin.visualize.Timelapse(args.logdir) # Estimate the length of the pendulum by backprop. parameters = [] if args.optimize_length: length_est = torch.nn.Parameter(torch.tensor([0.5], device=device), requires_grad=False) lengthmodel = SimpleModel(length_est).to(device) parameters += list(lengthmodel.parameters()) if args.optimize_gravity:
) + 1e-5 # print("method:", args.method, initial_velocity) # if args.method == "gradsim": # loss = ( # torch.nn.functional.mse_loss(imgs[-1], target_image) # + torch.nn.functional.mse_loss(imgs[-2], target_image) # + torch.nn.functional.mse_loss(imgs[-3], target_image) # + torch.nn.functional.mse_loss(imgs[-4], target_image) # + torch.nn.functional.mse_loss(imgs[-5], target_image) # ) if args.log: write_imglist_to_gif( imgs, os.path.join(logdir, f"{e:02d}.gif"), imgformat="rgba", verbose=False, ) write_imglist_to_dir( imgs, os.path.join(logdir, f"{e:02d}"), imgformat="rgba", ) imageio.imwrite( os.path.join(logdir, f"last_frame_{e:02d}.png"), (imgs[-1][0].permute(1, 2, 0).detach().cpu().numpy() * 255).astype( np.uint8 ), ) write_meshes_to_file( vertices, faces.detach().cpu().numpy(), os.path.join(logdir, f"vertices_{e:05d}")
# sys.exit(0) # imgs_gt.append(rgb) # SoftRas rgba = renderer.forward( state.q.unsqueeze(0).to(device), faces.unsqueeze(0).to(device), textures.to(device), ) imgs_gt.append(rgba) cloth_path = Path("cache/cloth") cloth_path.mkdir(exist_ok=True) write_imglist_to_gif(imgs_gt, cloth_path / "gt.gif", imgformat="rgba", verbose=False) init_inv_mass = particle_inv_mass_gt.clone() init_inv_mass[-1] = 0.0 init_inv_mass[-2] = 0.0 massmodel = SimpleModel( # particle_inv_mass_gt + 50 * torch.rand_like(particle_inv_mass_gt), init_inv_mass, activation=torch.nn.functional.relu, ) velocitymodel = SimpleModel(-0.01 * torch.ones_like(particle_velocity_gt), activation=None) epochs = 50 save_gif_every = 1 compare_every = 1
# print(state.rigid_x.shape) # v_in = torch.from_numpy(np.asarray(model_gt.shape_geo_src[0].vertices)).float() # print(torch.allclose(v_in, vertices)) vertices_current = get_world_vertices( vertices, state_gt.rigid_r.view(-1), state_gt.rigid_x) rgba = renderer.forward( vertices_current.unsqueeze(0).to(device), faces.unsqueeze(0).to(device), textures.to(device), ) imgs_gt.append(rgba) positions_gt.append(state_gt.rigid_x) logvertices_gt.append(vertices_current.detach().cpu().numpy()) if args.log: write_imglist_to_gif(imgs_gt, os.path.join(logdir, "gt.gif")) write_meshes_to_file(logvertices_gt, faces.detach().cpu().numpy(), os.path.join(logdir, "vertices_gt")) # """ # Optimize for physical parameters. # """ # # ke, kd, kf, mu # shape_material_guesses = (9500, 950, 950, 0.5) # builder = df.sim.ModelBuilder() # rigid = builder.add_rigid_body( # pos=pos_gt, rot=rot_gt, vel=vel_gt, omega=omega_gt, # )
colors_bxpx3=textures_gt, # uv_bxpx2=uv_gt, # texture_bx3xthxtw=textureimg_gt, # lightparam=lightparam, ) rgba = torch.cat((img_gt, alpha_gt), dim=-1) # rgba = renderer.forward( # body_gt.get_world_vertices().unsqueeze(0), faces_gt, textures_gt # ) imgs_gt.append(rgba) # writer.append_data((255 * img).astype(np.uint8)) # writer.close() if args.log: write_imglist_to_gif(imgs_gt, os.path.join(logdir, "gt.gif"), imgformat="dibr") # Load the template mesh (usually a sphere). mesh = TriangleMesh.from_obj(args.template) vertices = meshutils.normalize_vertices( mesh.vertices.unsqueeze(0)).to(device) faces = mesh.faces.to(device).unsqueeze(0) # uv = get_spherical_coords_x(vertices[0].cpu().numpy()) # uv = torch.from_numpy(uv).to(device).float().unsqueeze(0) / 255.0 # Paint the mesh yellow textures = torch.stack( ( torch.ones( 1, vertices.shape[-2], dtype=torch.float32, device=device), torch.ones(
) imgs.append(rgba) if not args.inference: loss.backward() optimizer.step() optimizer.zero_grad() tqdm.write(f"Loss: {loss.item():.5}") render_time += 1 if args.inference: filename = os.path.join("cache", "jellyfish", "debug", "inference") else: filename = os.path.join("cache", "jellyfish", "debug", f"{render_time:02d}") write_imglist_to_gif(imgs, f"{filename}.gif", imgformat="rgba", verbose=False) if args.inference: filename = os.path.join("cache", "jellyfish", "debug", "inference") else: filename = os.path.join( "cache", "jellyfish", "debug", f"last_frame_{render_time:02d}" ) imageio.imwrite( f"{filename}.png", (imgs[-1][0].permute(1, 2, 0).detach().cpu().numpy() * 255).astype( np.uint8 ), ) torch.save(network, "cache/jellyfish/debug/model.pt")