def test_one(opt, cage_shape, new_source, new_source_face, new_target, new_target_face): states = torch.load(opt.ckpt) if "states" in states: states = states["states"] pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-initial.ply"), states["source_vertices"][0].transpose( 0, 1).detach().cpu(), states["source_faces"][0].detach().cpu()) # states["template_vertices"] = cage_shape.transpose(1, 2) # states["source_vertices"] = new_source.transpose(1, 2) # states["source_faces"] = new_source_face pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-Sa.ply"), new_source[0].detach().cpu(), new_source_face[0].detach().cpu()) pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-Sb.ply"), new_target[0].detach().cpu(), new_target_face[0].detach().cpu()) net = networks.FixedSourceDeformer(opt, 3, opt.num_point, bottleneck_size=512, template_vertices=cage_shape.transpose(1, 2), template_faces=states["template_faces"].cuda(), source_vertices=new_source.transpose(1, 2), source_faces=new_source_face).cuda() net.eval() load_network(net, states) outputs = net(new_target.transpose(1, 2).contiguous()) deformed = outputs["deformed"] pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-Sab.ply"), deformed[0].detach().cpu(), new_target_face[0].detach().cpu())
def train(): dataset = build_dataset(opt) dataloader = torch.utils.data.DataLoader( dataset, batch_size=opt.batch_size, shuffle=True, drop_last=True, num_workers=0, worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] + id)) source_shape = dataset.mesh_vertex.unsqueeze(0).to(dtype=torch.float) source_face = dataset.mesh_face.unsqueeze(0) cage_shape = dataset.cage_vertex.unsqueeze(0).to(dtype=torch.float) cage_face = dataset.cage_face.unsqueeze(0) mesh = Mesh(vertices=cage_shape[0], faces=cage_face[0]) build_gemm(mesh, cage_face[0]) cage_edge_points = torch.from_numpy(get_edge_points(mesh)).cuda() cage_edges = edge_vertex_indices(cage_face[0]) # network net = networks.FixedSourceDeformer( opt, 3, opt.num_point, bottleneck_size=opt.bottleneck_size, template_vertices=cage_shape.transpose(1, 2), template_faces=cage_face, source_vertices=source_shape.transpose(1, 2), source_faces=source_face).cuda() print(net) net.apply(weights_init) if opt.ckpt: load_network(net, opt.ckpt) net.train() all_losses = losses.AllLosses(opt) # optimizer optimizer = torch.optim.Adam([{ 'params': net.nd_decoder.parameters() }, { "params": net.encoder.parameters() }], lr=opt.lr) # train os.makedirs(opt.log_dir, exist_ok=True) shutil.copy2(__file__, opt.log_dir) shutil.copy2(os.path.join(os.path.dirname(__file__), "network2.py"), opt.log_dir) shutil.copy2(os.path.join(os.path.dirname(__file__), "common.py"), opt.log_dir) shutil.copy2(os.path.join(os.path.dirname(__file__), "losses.py"), opt.log_dir) shutil.copy2(os.path.join(os.path.dirname(__file__), "datasets.py"), opt.log_dir) pymesh.save_mesh_raw( os.path.join(opt.log_dir, "t{:06d}_Sa.ply".format(0)), net.source_vertices[0].transpose(0, 1).detach().cpu().numpy(), net.source_faces[0].detach().cpu()) pymesh.save_mesh_raw( os.path.join(opt.log_dir, "t{:06d}_template.ply".format(0)), net.template_vertices[0].transpose(0, 1).detach().cpu().numpy(), net.template_faces[0].detach().cpu()) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, max(int(opt.nepochs * 0.75), 1), gamma=0.5, last_epoch=-1) # train net.train() t = 0 start_epoch = 0 warmed_up = False mvc_weight = opt.mvc_weight opt.mvc_weight = 0 os.makedirs(opt.log_dir, exist_ok=True) running_avg_loss = -1 log_file = open(os.path.join(opt.log_dir, "loss_log.txt"), "a") log_interval = min(max(len(dataloader) // 5, 50), 200) save_interval = max(opt.nepochs // 10, 1) with torch.autograd.detect_anomaly(): if opt.epoch: start_epoch = opt.epoch % opt.nepochs t += start_epoch * len(dataloader) for epoch in range(start_epoch, opt.nepochs): for epoch_t, data in enumerate(dataloader): progress = epoch_t / len(dataloader) + epoch warming_up = progress < opt.warmup_epochs if (opt.deform_template or opt.optimize_template) and ( progress >= opt.warmup_epochs) and (not warmed_up): if opt.deform_template: optimizer.add_param_group({ 'params': net.nc_decoder.parameters(), 'lr': 0.1 * opt.lr }) if opt.optimize_template: optimizer.add_param_group({ 'params': net.template_vertices, 'lr': 0.1 * opt.lr }) warmed_up = True # start to compute mvc weight opt.mvc_weight = mvc_weight save_network(net, opt.log_dir, network_label="net", epoch_label="warmed_up") ############# get data ########### data = dataset.uncollate(data) data["cage_edge_points"] = cage_edge_points data["cage_edges"] = cage_edges data["source_shape"] = net.source_vertices.detach() data["source_face"] = net.source_faces.detach() ############# run network ########### optimizer.zero_grad() target_shape_t = data["target_shape"].transpose(1, 2) sample_idx = None if "sample_idx" in data: sample_idx = data["sample_idx"] if data["source_normals"] is not None: data["source_normals"] = torch.gather( data["source_normals"], 1, sample_idx.unsqueeze(-1).expand(-1, -1, 3)) outputs = net(target_shape_t, sample_idx) if opt.sfnormal_weight > 0 and ("source_mesh" in data and "source_mesh" is not None): if outputs["deformed"].shape[1] == data[ "source_mesh"].shape[1]: outputs["deformed_hr"] = outputs["deformed"] else: outputs["deformed_hr"] = deform_with_MVC( outputs["cage"].expand( data["source_mesh"].shape[0], -1, -1).detach(), outputs["new_cage"], outputs["cage_face"].expand( data["source_mesh"].shape[0], -1, -1), data["source_mesh"]) data["source_shape"] = outputs["source_shape"] ############# get losses ########### current_loss = all_losses(data, outputs, progress) loss_sum = torch.sum( torch.stack([v for v in current_loss.values()], dim=0)) if running_avg_loss < 0: running_avg_loss = loss_sum else: running_avg_loss = running_avg_loss + ( loss_sum.item() - running_avg_loss) / (t + 1) if (t % log_interval == 0) or (loss_sum > 10 * running_avg_loss): log_str = "warming up {} e {:03d} t {:05d}: {}".format( not warmed_up, epoch, t, ", ".join([ "{} {:.3g}".format(k, v.mean().item()) for k, v in current_loss.items() ])) print(log_str) log_file.write(log_str + "\n") log_outputs(opt, t, outputs, data) # save_ply(data["target_shape"][0].detach().cpu().numpy(), os.path.join(opt.log_dir,"step-{:06d}-Sb.ply".format(t))) # save_ply(outputs["deformed"][0].detach().cpu().numpy(), os.path.join(opt.log_dir,"step-{:06d}-Sab.ply".format(t))) # write_trimesh(os.path.join(opt.log_dir, "step-{:06d}-cage1.ply".format(t)), # outputs["cage"][0].detach().cpu(), outputs["cage_face"][0].detach().cpu(), binary=True) # write_trimesh(os.path.join(opt.log_dir, "step-{:06d}-cage2.ply".format(t)), # outputs["new_cage"][0].detach().cpu(), outputs["cage_face"][0].detach().cpu(), binary=True) if loss_sum > 100 * running_avg_loss: logger.info( "loss ({}) > 10*running_average_loss ({}). Skip without update." .format(loss_sum, 5 * running_avg_loss)) torch.cuda.empty_cache() continue loss_sum.backward() if opt.alternate_cd: optimize_C = (progress > opt.warmup_epochs) and ( t % (opt.c_step + opt.d_step)) > opt.d_step if optimize_C: net.nd_decoder.zero_grad() net.encoder.zero_grad() else: try: net.nc_decoder.zero_grad() except AttributeError: net.template_vertices.grad.zero_() # clamp_gradient_norm(net, 1) optimizer.step() if (t + 1) % 500 == 0: save_network(net, opt.log_dir, network_label="net", epoch_label="latest") t += 1 if (epoch + 1) % save_interval == 0: save_network(net, opt.log_dir, network_label="net", epoch_label=epoch) scheduler.step() log_file.close() save_network(net, opt.log_dir, network_label="net", epoch_label="final") test_all(net=net)
def test(net=None, subdir="test"): opt.phase = "test" if isinstance(opt.target_model, str): opt.target_model = [opt.target_model] if net is None: states = torch.load(opt.ckpt) if "states" in states: states = states["states"] if opt.template: cage_shape, cage_face = read_trimesh(opt.template) cage_shape = torch.from_numpy( cage_shape[:, :3]).unsqueeze(0).float() cage_face = torch.from_numpy(cage_face).unsqueeze(0).long() states["template_vertices"] = cage_shape.transpose(1, 2) states["template_faces"] = cage_face if opt.source_model: source_shape, source_face = read_trimesh(opt.source_model) source_shape = torch.from_numpy( source_shape[:, :3]).unsqueeze(0).float() source_face = torch.from_numpy(source_face).unsqueeze(0).long() states["source_vertices"] = source_shape.transpose(1, 2) states["source_faces"] = source_shape net = networks.FixedSourceDeformer( opt, 3, opt.num_point, bottleneck_size=opt.bottleneck_size, template_vertices=states["template_vertices"], template_faces=states["template_faces"], source_vertices=states["source_vertices"], source_faces=states["source_faces"]).cuda() load_network(net, states) net = net.cuda() net.eval() else: net.eval() print(net) test_output_dir = os.path.join(opt.log_dir, subdir) os.makedirs(test_output_dir, exist_ok=True) with torch.no_grad(): for target_model in opt.target_model: assert (os.path.isfile(target_model)) target_face = None target_shape, target_face = read_trimesh(target_model) # target_shape = read_ply(target_model)[:,:3] # target_shape, _, scale = normalize_to_box(target_shape) # normalize acording to height y axis # target_shape = target_shape/2*1.7 target_shape = torch.from_numpy( target_shape[:, :3]).cuda().float().unsqueeze(0) if target_face is None: target_face = net.source_faces else: target_face = torch.from_numpy( target_face).cuda().long().unsqueeze(0) t_filename = os.path.splitext(os.path.basename(target_model))[0] source_mesh = net.source_vertices.transpose(1, 2).detach() source_face = net.source_faces.detach() # furthest sampling target_shape_sampled = furthest_point_sample( target_shape, net.source_vertices.shape[2], NCHW=False)[1] # target_shape_sampled = (target_shape[:, np.random.permutation(target_shape.shape[1]), :]).contiguous() outputs = net(target_shape_sampled.transpose(1, 2), None, cage_only=True) # deformed = outputs["deformed"] deformed = deform_with_MVC( outputs["cage"], outputs["new_cage"], outputs["cage_face"].expand(outputs["cage"].shape[0], -1, -1), source_mesh) b = 0 save_ply( target_shape_sampled[b].cpu().numpy(), os.path.join(opt.log_dir, subdir, "template-{}-Sb.pts".format(t_filename))) pymesh.save_mesh_raw( os.path.join(opt.log_dir, subdir, "template-{}-Sa.ply".format(t_filename)), source_mesh[0].detach().cpu(), source_face[0].detach().cpu()) pymesh.save_mesh_raw( os.path.join(opt.log_dir, subdir, "template-{}-Sb.ply".format(t_filename)), target_shape[b].detach().cpu(), target_face[b].detach().cpu()) pymesh.save_mesh_raw( os.path.join(opt.log_dir, subdir, "template-{}-Sab.ply".format(t_filename)), deformed[b].detach().cpu(), source_face[b].detach().cpu()) pymesh.save_mesh_raw( os.path.join(opt.log_dir, subdir, "template-{}-cage1.ply".format(t_filename)), outputs["cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu()) pymesh.save_mesh_raw( os.path.join(opt.log_dir, subdir, "template-{}-cage2.ply".format(t_filename)), outputs["new_cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu()) PairedSurreal.render_result(test_output_dir)
def test_all(net=None, subdir="test"): opt.phase = "test" dataset = build_dataset(opt) if net is None: source_shape = dataset.mesh_vertex.unsqueeze(0).to(dtype=torch.float) source_face = dataset.mesh_face.unsqueeze(0) cage_shape = dataset.cage_vertex.unsqueeze(0).to(dtype=torch.float) cage_face = dataset.cage_face.unsqueeze(0) net = networks.FixedSourceDeformer( opt, 3, opt.num_point, bottleneck_size=opt.bottleneck_size, template_vertices=cage_shape.transpose(1, 2), template_faces=cage_face, source_vertices=source_shape.transpose(1, 2), source_faces=source_face).cuda() load_network(net, opt.ckpt) net.eval() else: net.eval() print(net) dataloader = torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=3, worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] + id)) chamfer_distance = losses.LabeledChamferDistance(beta=0, gamma=1) mse_distance = torch.nn.MSELoss() avg_CD = 0 avg_EMD = 0 test_output_dir = os.path.join(opt.log_dir, subdir) os.makedirs(test_output_dir, exist_ok=True) with open(os.path.join(test_output_dir, "eval.txt"), "w") as f: with torch.no_grad(): source_mesh = net.source_vertices.transpose(1, 2).detach() source_face = net.source_faces.detach() for i, data in enumerate(dataloader): data = dataset.uncollate(data) target_shape, target_filename = data["target_shape"], data[ "target_file"] sample_idx = None if "sample_idx" in data: sample_idx = data["sample_idx"] outputs = net(target_shape.transpose(1, 2), sample_idx) deformed = outputs["deformed"] deformed = deform_with_MVC( outputs["cage"], outputs["new_cage"], outputs["cage_face"].expand(outputs["cage"].shape[0], -1, -1), source_mesh) for b in range(outputs["deformed"].shape[0]): t_filename = os.path.splitext(target_filename[b])[0] target_shape_np = target_shape.detach().cpu()[b].numpy() if data["target_face"] is not None and data[ "target_mesh"] is not None: pymesh.save_mesh_raw( os.path.join( opt.log_dir, subdir, "template-{}-Sa.ply".format(t_filename)), source_mesh[0].detach().cpu(), source_face[0].detach().cpu()) pymesh.save_mesh_raw( os.path.join( opt.log_dir, subdir, "template-{}-Sb.ply".format(t_filename)), data["target_mesh"][b].detach().cpu(), data["target_face"][b].detach().cpu()) pymesh.save_mesh_raw( os.path.join( opt.log_dir, subdir, "template-{}-Sab.ply".format(t_filename)), deformed[b].detach().cpu(), source_face[b].detach().cpu()) else: save_ply( source_mesh[0].detach().cpu(), os.path.join( opt.log_dir, subdir, "template-{}-Sa.ply".format(t_filename))) save_ply( target_shape[b].detach().cpu(), os.path.join( opt.log_dir, subdir, "template-{}-Sb.ply".format(t_filename)), normals=data["target_normals"][b].detach().cpu()) save_ply( deformed[b].detach().cpu(), os.path.join( opt.log_dir, subdir, "template-{}-Sab.ply".format(t_filename)), normals=data["target_normals"][b].detach().cpu()) pymesh.save_mesh_raw( os.path.join( opt.log_dir, subdir, "template-{}-cage1.ply".format(t_filename)), outputs["cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu()) pymesh.save_mesh_raw( os.path.join( opt.log_dir, subdir, "template-{}-cage2.ply".format(t_filename)), outputs["new_cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu()) log_str = "{}/{} {}".format(i, len(dataloader), t_filename) print(log_str) f.write(log_str + "\n") dataset.render_result(test_output_dir)
def test_all(opt, new_cage_shape): opt.phase = "test" opt.target_model = None print(opt.model) if opt.is_poly: source_mesh = om.read_polymesh(opt.model) else: source_mesh = om.read_trimesh(opt.model) dataset = build_dataset(opt) dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False, collate_fn=tolerating_collate, num_workers=0, worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] + id)) states = torch.load(opt.ckpt) if "states" in states: states = states["states"] # states["template_vertices"] = new_cage_shape.transpose(1, 2) # states["source_vertices"] = new_source.transpose(1,2) # states["source_faces"] = new_source_face # new_source_face = states["source_faces"] om.write_mesh(os.path.join(opt.log_dir, opt.subdir, "template-Sa.ply"), source_mesh) net = networks.FixedSourceDeformer(opt, 3, opt.num_point, bottleneck_size=opt.bottleneck_size, template_vertices=states["template_vertices"], template_faces=states["template_faces"].cuda(), source_vertices=states["source_vertices"], source_faces=states["source_faces"]).cuda() load_network(net, states) source_points = torch.from_numpy( source_mesh.points().copy()).float().cuda().unsqueeze(0) with torch.no_grad(): # source_face = net.source_faces.detach() for i, data in enumerate(dataloader): data = dataset.uncollate(data) target_shape, target_filename = data["target_shape"], data["target_file"] logger.info("", data["target_file"][0]) sample_idx = None if "sample_idx" in data: sample_idx = data["sample_idx"] outputs = net(target_shape.transpose(1, 2), cage_only=True) if opt.d_residual: cage_offset = outputs["new_cage"]-outputs["cage"] outputs["cage"] = new_cage_shape outputs["new_cage"] = new_cage_shape+cage_offset deformed = deform_with_MVC(outputs["cage"], outputs["new_cage"], outputs["cage_face"].expand( outputs["cage"].shape[0], -1, -1), source_points) for b in range(deformed.shape[0]): t_filename = os.path.splitext(target_filename[b])[0] source_mesh_arr = source_mesh.points() source_mesh_arr[:] = deformed[0].cpu().detach().numpy() om.write_mesh(os.path.join( opt.log_dir, opt.subdir, "template-{}-Sab.obj".format(t_filename)), source_mesh) # if data["target_face"] is not None and data["target_mesh"] is not None: # pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sa.ply".format(t_filename)), # source_mesh[0].detach().cpu(), source_face[b].detach().cpu()) pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sb.ply".format(t_filename)), data["target_mesh"][b].detach().cpu(), data["target_face"][b].detach().cpu()) # pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sab.ply".format(t_filename)), # deformed[b].detach().cpu(), source_face[b].detach().cpu()) # else: # save_ply(source_mesh[0].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sa.ply".format(t_filename))) # save_ply(target_shape[b].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sb.ply".format(t_filename)), # normals=data["target_normals"][b].detach().cpu()) # save_ply(deformed[b].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sab.ply".format(t_filename)), # normals=data["target_normals"][b].detach().cpu()) pymesh.save_mesh_raw( os.path.join(opt.log_dir, opt.subdir, "template-{}-cage1.ply".format(t_filename)), outputs["cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu(), ) pymesh.save_mesh_raw( os.path.join(opt.log_dir, opt.subdir, "template-{}-cage2.ply".format(t_filename)), outputs["new_cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu(), ) # if opt.opt_lap and deformed.shape[1] == source_mesh.shape[1]: # deformed = optimize_lap(opt, source_mesh, deformed, source_face) # for b in range(deformed.shape[0]): # pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sab-optlap.ply".format(t_filename)), # deformed[b].detach().cpu(), source_face[b].detach().cpu()) if i % 20 == 0: logger.success("[{}/{}] Done".format(i, len(dataloader))) dataset.render_result(os.path.join(opt.log_dir, opt.subdir))