def train():
    """Train the FixedSourceDeformer: learn to deform one fixed source shape
    toward sampled target shapes via a cage-based deformation.

    Relies on module-level globals: ``opt`` (options namespace), ``networks``,
    ``losses``, ``logger``, and the mesh/IO helpers imported by this file.
    Side effects: writes checkpoints, logs, and .ply snapshots to ``opt.log_dir``.
    """
    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=0,
        # reseed numpy per worker so workers don't produce identical augmentations
        worker_init_fn=lambda worker_id: np.random.seed(
            np.random.get_state()[1][0] + worker_id))

    # Fixed source mesh and initial cage, promoted to batch dimension 1.
    source_shape = dataset.mesh_vertex.unsqueeze(0).to(dtype=torch.float)
    source_face = dataset.mesh_face.unsqueeze(0)
    cage_shape = dataset.cage_vertex.unsqueeze(0).to(dtype=torch.float)
    cage_face = dataset.cage_face.unsqueeze(0)

    # Precompute cage edge connectivity (used by edge-based regularizers).
    mesh = Mesh(vertices=cage_shape[0], faces=cage_face[0])
    build_gemm(mesh, cage_face[0])
    cage_edge_points = torch.from_numpy(get_edge_points(mesh)).cuda()
    cage_edges = edge_vertex_indices(cage_face[0])

    # network
    net = networks.FixedSourceDeformer(
        opt, 3, opt.num_point,
        bottleneck_size=opt.bottleneck_size,
        template_vertices=cage_shape.transpose(1, 2),
        template_faces=cage_face,
        source_vertices=source_shape.transpose(1, 2),
        source_faces=source_face).cuda()
    print(net)
    net.apply(weights_init)
    if opt.ckpt:
        load_network(net, opt.ckpt)
    net.train()

    all_losses = losses.AllLosses(opt)

    # optimizer: only the deformation decoder and encoder train from the start;
    # cage-related parameters are added after warmup (see loop below).
    optimizer = torch.optim.Adam(
        [{'params': net.nd_decoder.parameters()},
         {"params": net.encoder.parameters()}],
        lr=opt.lr)

    # Snapshot the experiment: copy the relevant sources into the log dir.
    os.makedirs(opt.log_dir, exist_ok=True)
    shutil.copy2(__file__, opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "network2.py"), opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "common.py"), opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "losses.py"), opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "datasets.py"), opt.log_dir)

    # Dump the initial source shape and cage template for reference.
    pymesh.save_mesh_raw(
        os.path.join(opt.log_dir, "t{:06d}_Sa.ply".format(0)),
        net.source_vertices[0].transpose(0, 1).detach().cpu().numpy(),
        net.source_faces[0].detach().cpu())
    pymesh.save_mesh_raw(
        os.path.join(opt.log_dir, "t{:06d}_template.ply".format(0)),
        net.template_vertices[0].transpose(0, 1).detach().cpu().numpy(),
        net.template_faces[0].detach().cpu())

    # Halve the learning rate once at 75% of training.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, max(int(opt.nepochs * 0.75), 1), gamma=0.5, last_epoch=-1)

    # train
    net.train()
    t = 0
    start_epoch = 0
    warmed_up = False
    # MVC loss is disabled until warmup completes; remember the configured weight.
    mvc_weight = opt.mvc_weight
    opt.mvc_weight = 0

    os.makedirs(opt.log_dir, exist_ok=True)
    running_avg_loss = -1
    log_file = open(os.path.join(opt.log_dir, "loss_log.txt"), "a")
    log_interval = min(max(len(dataloader) // 5, 50), 200)
    save_interval = max(opt.nepochs // 10, 1)

    # NOTE(review): detect_anomaly slows training considerably; presumably kept
    # on intentionally for debugging NaNs — confirm before shipping.
    with torch.autograd.detect_anomaly():
        if opt.epoch:
            start_epoch = opt.epoch % opt.nepochs
            t += start_epoch * len(dataloader)

        for epoch in range(start_epoch, opt.nepochs):
            for epoch_t, data in enumerate(dataloader):
                # Fractional epoch progress, used for loss scheduling.
                progress = epoch_t / len(dataloader) + epoch

                # After warmup, optionally unfreeze cage deformation/template
                # parameters (once), re-enable the MVC loss, and checkpoint.
                if (opt.deform_template or opt.optimize_template) and (
                        progress >= opt.warmup_epochs) and (not warmed_up):
                    if opt.deform_template:
                        optimizer.add_param_group(
                            {'params': net.nc_decoder.parameters(),
                             'lr': 0.1 * opt.lr})
                    if opt.optimize_template:
                        optimizer.add_param_group(
                            {'params': net.template_vertices,
                             'lr': 0.1 * opt.lr})
                    warmed_up = True
                    # start to compute mvc weight
                    opt.mvc_weight = mvc_weight
                    save_network(net, opt.log_dir, network_label="net",
                                 epoch_label="warmed_up")

                ############# get data ###########
                data = dataset.uncollate(data)
                data["cage_edge_points"] = cage_edge_points
                data["cage_edges"] = cage_edges
                data["source_shape"] = net.source_vertices.detach()
                data["source_face"] = net.source_faces.detach()

                ############# run network ###########
                optimizer.zero_grad()
                target_shape_t = data["target_shape"].transpose(1, 2)
                sample_idx = None
                if "sample_idx" in data:
                    sample_idx = data["sample_idx"]
                    if data["source_normals"] is not None:
                        # Subsample normals consistently with the sampled points.
                        data["source_normals"] = torch.gather(
                            data["source_normals"], 1,
                            sample_idx.unsqueeze(-1).expand(-1, -1, 3))
                outputs = net(target_shape_t, sample_idx)

                # BUGFIX: original tested the string literal
                # ("source_mesh" is not None), which is always true;
                # the intent is to check the value stored in data.
                if opt.sfnormal_weight > 0 and (
                        "source_mesh" in data and data["source_mesh"] is not None):
                    if outputs["deformed"].shape[1] == data["source_mesh"].shape[1]:
                        outputs["deformed_hr"] = outputs["deformed"]
                    else:
                        # Deform the high-resolution source mesh with the
                        # predicted cages via mean value coordinates.
                        outputs["deformed_hr"] = deform_with_MVC(
                            outputs["cage"].expand(
                                data["source_mesh"].shape[0], -1, -1).detach(),
                            outputs["new_cage"],
                            outputs["cage_face"].expand(
                                data["source_mesh"].shape[0], -1, -1),
                            data["source_mesh"])
                data["source_shape"] = outputs["source_shape"]

                ############# get losses ###########
                current_loss = all_losses(data, outputs, progress)
                loss_sum = torch.sum(
                    torch.stack([v for v in current_loss.values()], dim=0))

                # Running average of the scalar loss (detached from the graph).
                if running_avg_loss < 0:
                    running_avg_loss = loss_sum.item()
                else:
                    running_avg_loss = running_avg_loss + (
                        loss_sum.item() - running_avg_loss) / (t + 1)

                if (t % log_interval == 0) or (loss_sum > 10 * running_avg_loss):
                    log_str = "warming up {} e {:03d} t {:05d}: {}".format(
                        not warmed_up, epoch, t,
                        ", ".join(["{} {:.3g}".format(k, v.mean().item())
                                   for k, v in current_loss.items()]))
                    print(log_str)
                    log_file.write(log_str + "\n")
                    log_outputs(opt, t, outputs, data)

                # Skip outlier batches whose loss explodes past the running average.
                if loss_sum > 100 * running_avg_loss:
                    # BUGFIX: message previously claimed a 10x threshold and
                    # printed 5*running_avg_loss; report the real values.
                    logger.info(
                        "loss ({}) > 100*running_average_loss ({}). Skip without update."
                        .format(loss_sum, 100 * running_avg_loss))
                    torch.cuda.empty_cache()
                    continue

                loss_sum.backward()

                # Alternate cage (C) / deformation (D) optimization by zeroing
                # the gradients of whichever half is frozen this step.
                if opt.alternate_cd:
                    optimize_C = (progress > opt.warmup_epochs) and (
                        t % (opt.c_step + opt.d_step)) > opt.d_step
                    if optimize_C:
                        net.nd_decoder.zero_grad()
                        net.encoder.zero_grad()
                    else:
                        try:
                            net.nc_decoder.zero_grad()
                        except AttributeError:
                            # no cage decoder: the template vertices are the
                            # directly optimized cage parameters
                            net.template_vertices.grad.zero_()

                optimizer.step()

                if (t + 1) % 500 == 0:
                    save_network(net, opt.log_dir, network_label="net",
                                 epoch_label="latest")
                t += 1

            if (epoch + 1) % save_interval == 0:
                save_network(net, opt.log_dir, network_label="net",
                             epoch_label=epoch)
            scheduler.step()

    log_file.close()
    save_network(net, opt.log_dir, network_label="net", epoch_label="final")
    test_all(net=net)
def train():
    """Train the full source-to-target cage deformation network.

    NOTE(review): this is the second ``train`` defined in this SOURCE and
    shadows the first — presumably two scripts were concatenated; confirm
    which definition callers expect.

    Relies on module-level globals: ``opt``, ``networks``, ``losses``,
    ``logger``, and the helpers imported by this file. Side effects: writes
    checkpoints and logs to ``opt.log_dir``; runs ``test`` at the end.
    """
    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=tolerating_collate,
        num_workers=2,
        # reseed numpy per worker so workers don't produce identical augmentations
        worker_init_fn=lambda worker_id: np.random.seed(
            np.random.get_state()[1][0] + worker_id))

    if opt.dim == 3:
        # cage (1,N,3)
        init_cage_V, init_cage_Fs = loadInitCage([opt.template])
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        cage_edge_points_list = []
        cage_edges_list = []
        for F in init_cage_Fs:
            mesh = Mesh(vertices=init_cage_V[0], faces=F[0])
            build_gemm(mesh, F[0])
            cage_edge_points = torch.from_numpy(get_edge_points(mesh)).cuda()
            cage_edge_points_list.append(cage_edge_points)
            # BUGFIX: original rebound the list each iteration
            # (cage_edges_list = [...]), keeping only the last entry;
            # append to stay parallel with cage_edge_points_list.
            cage_edges_list.append(edge_vertex_indices(F[0]))
    else:
        # 2D case: initial cage is a regular polygon of opt.cage_deg vertices.
        init_cage_V = generatePolygon(0, 0, 1.5, 0, 0, 0, opt.cage_deg)
        init_cage_V = torch.tensor([(x, y) for x, y in init_cage_V],
                                   dtype=torch.float).unsqueeze(0)
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        init_cage_Fs = [
            torch.arange(opt.cage_deg, dtype=torch.int64).view(1, 1, -1).cuda()
        ]

    # network
    net = networks.NetworkFull(
        opt,
        dim=opt.dim,
        bottleneck_size=opt.bottleneck_size,
        template_vertices=cage_V_t,
        template_faces=init_cage_Fs[-1],
    ).cuda()
    net.apply(weights_init)
    if opt.ckpt:
        load_network(net, opt.ckpt)

    all_losses = losses.AllLosses(opt)

    # optimizer: encoder, deformation decoder, and merger always train;
    # cage decoder / template vertices are opt-in with reduced learning rates.
    optimizer = torch.optim.Adam(
        [{"params": net.encoder.parameters()},
         {"params": net.nd_decoder.parameters()},
         {"params": net.merger.parameters()}],
        lr=opt.lr)
    if opt.full_net:
        optimizer.add_param_group(
            {'params': net.nc_decoder.parameters(), 'lr': 0.1 * opt.lr})
    if opt.optimize_template:
        optimizer.add_param_group(
            {'params': net.template_vertices, 'lr': opt.lr})

    # Decay the learning rate by 10x once at 40% of training.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, int(opt.nepochs * 0.4), gamma=0.1, last_epoch=-1)

    # train
    net.train()
    start_epoch = 0
    t = 0
    # Alternating C/D schedule lengths (in steps) for the default optimize_C flag.
    steps_C = 20
    steps_D = 20

    # Snapshot the experiment: copy the relevant sources into the log dir.
    os.makedirs(opt.log_dir, exist_ok=True)
    shutil.copy2(__file__, opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "networks.py"), opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "losses.py"), opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "datasets.py"), opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "common.py"), opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "option.py"), opt.log_dir)

    print(net)
    log_file = open(os.path.join(opt.log_dir, "training_log.txt"), "a")
    log_file.write(str(net) + "\n")

    log_interval = max(len(dataloader) // 5, 50)
    save_interval = max(opt.nepochs // 10, 1)
    running_avg_loss = -1

    # NOTE(review): detect_anomaly slows training considerably; presumably kept
    # on intentionally for debugging NaNs — confirm before shipping.
    with torch.autograd.detect_anomaly():
        if opt.epoch:
            start_epoch = opt.epoch % opt.nepochs
            t += start_epoch * len(dataloader)

        for epoch in range(start_epoch, opt.nepochs):
            for t_epoch, data in enumerate(dataloader):
                warming_up = epoch < opt.warmup_epochs
                # Fractional epoch progress, used for loss scheduling.
                progress = t_epoch / len(dataloader) + epoch
                optimize_C = (t % (steps_C + steps_D)) > steps_D

                ############# get data ###########
                data = dataset.uncollate(data)
                data = crisscross_input(data)
                if opt.dim == 3:
                    data["cage_edge_points"] = cage_edge_points_list[-1]
                    data["cage_edges"] = cage_edges_list[-1]
                source_shape, target_shape = (data["source_shape"],
                                              data["target_shape"])

                ############# blending ############
                # Per-sample random blend weight, or 1.0 when blending is off.
                if opt.blend_style:
                    blend_alpha = torch.rand(
                        (source_shape.shape[0], 1),
                        dtype=torch.float32).to(device=source_shape.device)
                else:
                    blend_alpha = 1.0
                data["alpha"] = blend_alpha

                ############# run network ###########
                optimizer.zero_grad()
                source_shape_t = source_shape.transpose(1, 2)
                target_shape_t = target_shape.transpose(1, 2)
                outputs = net(source_shape_t, target_shape_t, data["alpha"])

                ############# get losses ###########
                current_loss = all_losses(data, outputs, progress)
                loss_sum = torch.sum(
                    torch.stack([v for v in current_loss.values()], dim=0))

                # Running average of the scalar loss (detached from the graph).
                if running_avg_loss < 0:
                    running_avg_loss = loss_sum.item()
                else:
                    running_avg_loss = running_avg_loss + (
                        loss_sum.item() - running_avg_loss) / (t + 1)

                if (t % log_interval == 0) or (loss_sum > 5 * running_avg_loss):
                    log_str = "warming up {} e {:03d} t {:05d}: {}".format(
                        warming_up, epoch, t,
                        ", ".join(["{} {:.3g}".format(k, v.mean().item())
                                   for k, v in current_loss.items()]))
                    print(log_str)
                    log_file.write(log_str + "\n")
                    log_outputs(opt, t, outputs, data)

                # Skip outlier batches whose loss explodes past the running average.
                if loss_sum > 100 * running_avg_loss:
                    # BUGFIX: message previously claimed a 5x threshold and
                    # printed 5*running_avg_loss; report the real values.
                    logger.info(
                        "loss ({}) > 100*running_average_loss ({}). Skip without update."
                        .format(loss_sum, 100 * running_avg_loss))
                    torch.cuda.empty_cache()
                    continue

                loss_sum.backward()

                # During warmup, freeze the cage branch by zeroing its grads.
                if epoch < opt.warmup_epochs:
                    try:
                        net.nc_decoder.zero_grad()
                        net.encoder.zero_grad()
                    except AttributeError:
                        # no cage decoder: template vertices are optimized directly
                        net.template_vertices.grad.zero_()

                # Alternate cage (C) / deformation (D) optimization by zeroing
                # the gradients of whichever half is frozen this epoch.
                if opt.alternate_cd:
                    optimize_C = (epoch > opt.warmup_epochs) and (
                        epoch % (opt.c_epoch + opt.d_epoch)) > opt.d_epoch
                    if optimize_C:
                        net.nd_decoder.zero_grad()
                    else:
                        try:
                            net.encoder.zero_grad()
                            net.nc_decoder.zero_grad()
                        except AttributeError:
                            net.template_vertices.grad.zero_()

                clamp_gradient(net, 0.1)
                optimizer.step()

                if (t + 1) % 500 == 0:
                    save_network(net, opt.log_dir, network_label="net",
                                 epoch_label="latest")
                t += 1

            if (epoch + 1) % save_interval == 0:
                save_network(net, opt.log_dir, network_label="net",
                             epoch_label=epoch)
            scheduler.step()

            # Optional per-epoch evaluation; best-effort, never aborts training.
            if opt.eval:
                try:
                    test(net=net, save_subdir="epoch_{}".format(epoch))
                except Exception as e:
                    traceback.print_exc(file=sys.stdout)
                    # BUGFIX: logger.warn is deprecated, and the original passed
                    # str(e) as a %-format arg with no placeholder (dropped).
                    logger.warning("Failed to run test: %s", e)

    log_file.close()
    save_network(net, opt.log_dir, network_label="net", epoch_label="final")
    test(net=net)