def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument("mesh_filename", type=str, help="Point cloud to reconstruct")
    argparser.add_argument("radius", type=float, help="Patch radius (The parameter, r, in the paper)")
    argparser.add_argument("padding", type=float, help="Padding factor for patches (The parameter, c, in the paper)")
    argparser.add_argument("min_pts_per_patch", type=int,
                           help="Minimum number of points allowed inside a patch; patches with fewer points "
                                "are not fit.")
    argparser.add_argument("--output", "-o", type=str, default="out",
                           help="Name for the output files: e.g. if you pass in --output out, the program will save "
                                "a dense upsampled point cloud named out.ply, and a file containing reconstruction "
                                "metadata and model weights named out.pt. Default: out. "
                                "Note: the number of points per patch in the upsampled point cloud is 64 by default "
                                "and can be set by specifying --upsamples-per-patch.")
    argparser.add_argument("--upsamples-per-patch", "-nup", type=int, default=8,
                           help="*Square root* of the number of upsamples per patch to generate in the output. i.e. if "
                                "you pass in --upsamples-per-patch 8, there will be 64 upsamples per patch.")
    argparser.add_argument("--angle-threshold", "-a", type=float, default=95.0,
                           help="Threshold (in degrees) used to discard points in "
                                "a patch whose normal is facing the wrong way.")
    argparser.add_argument("--local-epochs", "-nl", type=int, default=128,
                           help="Number of fitting iterations done to fit each chart to its points")
    argparser.add_argument("--global-epochs", "-ng", type=int, default=128,
                           help="Number of fitting iterations done to make each chart agree "
                                "with its neighboring charts")
    argparser.add_argument("--learning-rate", "-lr", type=float, default=1e-3,
                           help="Step size for gradient descent.")
    argparser.add_argument("--devices", "-d", type=str, default=["cuda"], nargs="+",
                           help="A list of devices on which to partition the models for each patch. For large inputs, "
                                "reconstruction can be memory and compute intensive. Passing in multiple devices will "
                                "split the load across these. E.g. --devices cuda:0 cuda:1 cuda:2")
    argparser.add_argument("--plot", action="store_true",
                           help="Plot the following intermediate states: (1) patch neighborhoods, "
                                "(2) intermediate reconstruction before the global consistency step, "
                                "(3) reconstruction after the global consistency step. "
                                "This flag is useful for debugging but does not scale well to large inputs.")
    argparser.add_argument("--interpolate", action="store_true",
                           help="If set, then force all patches to agree with the input at overlapping points "
                                "(i.e. the reconstruction will try to interpolate the input point cloud). "
                                "Otherwise, we fit all patches to the average of overlapping patches at each point.")
    argparser.add_argument("--max-sinkhorn-iters", "-si", type=int, default=32,
                           help="Maximum number of Sinkhorn iterations")
    argparser.add_argument("--sinkhorn-epsilon", "-sl", type=float, default=1e-3,
                           help="The reciprocal (1/lambda) of the Sinkhorn regularization parameter.")
    argparser.add_argument("--seed", "-s", type=int, default=-1,
                           help="Random seed to use when initializing network weights. "
                                "If the seed is not positive, a seed is selected at random.")
    argparser.add_argument("--exact-emd", "-e", action="store_true",
                           help="Use exact optimal transport distance instead of Sinkhorn. "
                                "This will be slow and should not make a difference in the output")
    argparser.add_argument("--use-best", action="store_true",
                           help="Use the model with the lowest loss as the final result.")
    argparser.add_argument("--normal-neighborhood-size", "-ns", type=int, default=64,
                           help="Neighborhood size used to estimate the normals in the final dense point cloud. "
                                "Default: 64")
    argparser.add_argument("--save-pre-cc", action="store_true",
                           help="Save a copy of the model before the cycle consistency step")
    argparser.add_argument("--batch-size", type=int, default=-1,
                           help="Split the fitting of per-patch MLPs into batches of this size. "
                                "A value <= 0 fits all patches in a single batch.")
    args = argparser.parse_args()

    # We'll populate this dictionary and save it as output
    output_dict = {
        "pre_cycle_consistency_model": None,
        "final_model": None,
        "patch_uvs": None,
        "patch_idx": None,
        "patch_txs": None,
        "radius": args.radius,
        "padding": args.padding,
        "min_pts_per_patch": args.min_pts_per_patch,
        "angle_threshold": args.angle_threshold,
        "interpolate": args.interpolate,
        "global_epochs": args.global_epochs,
        "local_epochs": args.local_epochs,
        "learning_rate": args.learning_rate,
        "devices": args.devices,
        "sinkhorn_epsilon": args.sinkhorn_epsilon,
        "max_sinkhorn_iters": args.max_sinkhorn_iters,
        "seed": utils.seed_everything(args.seed),
        "batch_size": args.batch_size,
    }

    # Read a point cloud and normals from a file, center it about its mean, and align it along its principal axes
    x, n = utils.load_point_cloud_by_file_extension(args.mesh_filename, compute_normals=True)

    # Compute a set of neighborhoods (patches) and uv samples for each neighborhood. Store the result in a list
    # of pairs (uv_j, xi_j) where uv_j are 2D uv coordinates for the j^th patch, and xi_j are the indices into x of
    # the j^th patch. We will try to reconstruct a function phi, such that phi(uv_j) = x[xi_j].
    print("Computing neighborhoods...")
    bbox_diag = np.linalg.norm(np.max(x, axis=0) - np.min(x, axis=0))
    patch_idx, patch_uvs, patch_xs, patch_tx = compute_patches(x, n, args.radius * bbox_diag, args.padding,
                                                               angle_thresh=args.angle_threshold,
                                                               min_pts_per_patch=args.min_pts_per_patch)
    num_patches = len(patch_uvs)
    output_dict["patch_uvs"] = patch_uvs
    output_dict["patch_idx"] = patch_idx
    output_dict["patch_txs"] = patch_tx

    if args.plot:
        plot_patches(x, patch_idx)

    # Initialize one model per patch and convert the input data to a pytorch tensor
    print("Creating models...")
    if args.batch_size > 0:
        num_batches = int(np.ceil(num_patches / args.batch_size))
        batch_size = args.batch_size
        print("Splitting fitting into %d batches" % num_batches)
    else:
        num_batches = 1
        batch_size = num_patches

    phi = nn.ModuleList([MLP(2, 3) for _ in range(num_patches)])
    # x = torch.from_numpy(x.astype(np.float32)).to(args.device)
    phi_optimizers = []
    phi_optimizers_devices = []
    uv_optimizer = torch.optim.Adam(patch_uvs, lr=args.learning_rate)
    sinkhorn_loss = SinkhornLoss(max_iters=args.max_sinkhorn_iters, return_transport_matrix=True)
    mse_loss = nn.MSELoss()

    # Fit a function, phi_i, for each patch so that phi_i(patch_uvs[i]) = x[patch_idx[i]]. i.e. so that the function
    # phi_i "agrees" with the point cloud on each patch.
    #
    # We also store the correspondences between the uvs and points which we use later for the consistency step. The
    # correspondences are stored in a list, pi, where pi[i] is a vector of integers used to permute the points in
    # a patch.
    pi = [None for _ in range(num_patches)]

    # Cache the models with the lowest loss if --use-best is passed
    best_models = [None for _ in range(num_patches)]
    best_losses = [np.inf for _ in range(num_patches)]

    print("Training local patches...")
    for b in range(num_batches):
        print("Fitting batch %d/%d" % (b + 1, num_batches))
        start_idx = b * batch_size
        end_idx = min((b + 1) * batch_size, num_patches)

        optimizer_batch = torch.optim.Adam(phi[start_idx:end_idx].parameters(), lr=args.learning_rate)
        phi_optimizers.append(optimizer_batch)

        # Move this batch of models and data onto its assigned devices
        for i in range(start_idx, end_idx):
            dev_i = args.devices[i % len(args.devices)]
            phi[i] = phi[i].to(dev_i)
            patch_uvs[i] = patch_uvs[i].to(dev_i)
            patch_xs[i] = patch_xs[i].to(dev_i)

        for epoch in range(args.local_epochs):
            optimizer_batch.zero_grad()
            uv_optimizer.zero_grad()

            # sum_loss = torch.tensor([0.0]).to(args.devices[0])
            losses = []

            torch.cuda.synchronize()
            epoch_start_time = time.time()
            for i in range(start_idx, end_idx):
                uv_i = patch_uvs[i]
                x_i = patch_xs[i]
                y_i = phi[i](uv_i)

                with torch.no_grad():
                    if args.exact_emd:
                        M_i = pairwise_distances(x_i.unsqueeze(0), y_i.unsqueeze(0)).squeeze().cpu().numpy()
                        p_i = ot.emd(np.ones(x_i.shape[0]), np.ones(y_i.shape[0]), M_i)
                        p_i = torch.from_numpy(p_i.astype(np.float32)).to(args.devices[0])
                    else:
                        _, p_i = sinkhorn_loss(x_i.unsqueeze(0), y_i.unsqueeze(0))
                    pi_i = p_i.squeeze().max(0)[1]
                    pi[i] = pi_i

                loss_i = mse_loss(x_i[pi_i].unsqueeze(0), y_i.unsqueeze(0))

                if args.use_best and loss_i.item() < best_losses[i]:
                    best_losses[i] = loss_i.item()
                    model_copy = copy.deepcopy(phi[i]).to('cpu')
                    best_models[i] = copy.deepcopy(model_copy.state_dict())

                loss_i.backward()
                losses.append(loss_i)
                # sum_loss += loss_i.to(args.devices[0])

            # sum_loss.backward()
            sum_loss = sum([l.item() for l in losses])
            torch.cuda.synchronize()
            epoch_end_time = time.time()

            print("%d/%d: [Total = %0.5f] [Mean = %0.5f] [Time = %0.3f]" %
                  (epoch, args.local_epochs, sum_loss, sum_loss / (end_idx - start_idx),
                   epoch_end_time - epoch_start_time))
            optimizer_batch.step()
            uv_optimizer.step()

        # Move the batch back to the CPU to free device memory for the next batch
        for i in range(start_idx, end_idx):
            dev_i = 'cpu'
            phi[i] = phi[i].to(dev_i)
            patch_uvs[i] = patch_uvs[i].to(dev_i)
            patch_xs[i] = patch_xs[i].to(dev_i)
            pi[i] = pi[i].to(dev_i)
        optimizer_batch_devices = move_optimizer_to_device(optimizer_batch, 'cpu')
        phi_optimizers_devices.append(optimizer_batch_devices)
        print("Done batch %d/%d" % (b + 1, num_batches))

    print("Mean best losses:", np.mean(best_losses))

    if args.use_best:
        for i, phi_i in enumerate(phi):
            phi_i.load_state_dict(best_models[i])

    if args.save_pre_cc:
        output_dict["pre_cycle_consistency_model"] = copy.deepcopy(phi.state_dict())

    if args.plot:
        raise NotImplementedError("TODO: Fix plotting code")
        plot_reconstruction(x, patch_uvs, patch_tx, phi, scale=1.0 / args.padding)

    # Do a second, global, stage of fitting where we ask all patches to agree with each other on overlapping points.
    # If the user passed --interpolate, we ask that the patches agree on the original input points, otherwise we ask
    # that they agree on the average of predictions from patches overlapping a given point.
    if not args.interpolate:
        print("Computing patch means...")
        with torch.no_grad():
            patch_xs = patch_means(pi, patch_uvs, patch_idx, patch_tx, phi, x, args.devices, num_batches)

    print("Training cycle consistency...")
    for b in range(num_batches):
        print("Fitting batch %d/%d" % (b + 1, num_batches))
        start_idx = b * batch_size
        end_idx = min((b + 1) * batch_size, num_patches)

        for i in range(start_idx, end_idx):
            dev_i = args.devices[i % len(args.devices)]
            phi[i] = phi[i].to(dev_i)
            patch_uvs[i] = patch_uvs[i].to(dev_i)
            patch_xs[i] = patch_xs[i].to(dev_i)
            pi[i] = pi[i].to(dev_i)

        optimizer = phi_optimizers[b]
        move_optimizer_to_device(optimizer, phi_optimizers_devices[b])

        for epoch in range(args.global_epochs):
            optimizer.zero_grad()
            uv_optimizer.zero_grad()

            sum_loss = torch.tensor([0.0]).to(args.devices[0])
            epoch_start_time = time.time()
            for i in range(start_idx, end_idx):
                uv_i = patch_uvs[i]
                x_i = patch_xs[i]
                y_i = phi[i](uv_i)
                pi_i = pi[i]

                loss_i = mse_loss(x_i[pi_i].unsqueeze(0), y_i.unsqueeze(0))

                if loss_i.item() < best_losses[i]:
                    best_losses[i] = loss_i.item()
                    model_copy = copy.deepcopy(phi[i]).to('cpu')
                    best_models[i] = copy.deepcopy(model_copy.state_dict())

                sum_loss += loss_i.to(args.devices[0])

            sum_loss.backward()
            epoch_end_time = time.time()

            print("%d/%d: [Total = %0.5f] [Mean = %0.5f] [Time = %0.3f]" %
                  (epoch, args.global_epochs, sum_loss.item(), sum_loss.item() / (end_idx - start_idx),
                   epoch_end_time - epoch_start_time))
            optimizer.step()
            uv_optimizer.step()

        # Move the batch back to the CPU to free device memory for the next batch
        for i in range(start_idx, end_idx):
            dev_i = 'cpu'
            phi[i] = phi[i].to(dev_i)
            patch_uvs[i] = patch_uvs[i].to(dev_i)
            patch_xs[i] = patch_xs[i].to(dev_i)
            pi[i] = pi[i].to(dev_i)
        move_optimizer_to_device(optimizer, 'cpu')

    print("Mean best losses:", np.mean(best_losses))

    for i, phi_i in enumerate(phi):
        phi_i.load_state_dict(best_models[i])

    output_dict["final_model"] = phi.state_dict()

    print("Generating dense point cloud...")
    v, n = upsample_surface(patch_uvs, patch_tx, phi, args.devices,
                            scale=(1.0 / args.padding),
                            num_samples=args.upsamples_per_patch,
                            normal_samples=args.normal_neighborhood_size,
                            num_batches=num_batches,
                            compute_normals=False)

    print("Saving dense point cloud...")
    pcu.write_ply(args.output + ".ply", v, np.zeros([], dtype=np.int32), n, np.zeros([], dtype=v.dtype))

    print("Saving metadata...")
    torch.save(output_dict, args.output + ".pt")

    if args.plot:
        plot_reconstruction(x, patch_uvs, patch_tx, phi, scale=1.0 / args.padding)
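# Example invocation for the script above (the script and file names here are hypothetical; only the argument
# names and defaults come from the parser above). With a bounding-box-relative radius of 0.01, a padding factor
# of 1.0, at least 10 points per patch, and the per-patch MLPs split across two GPUs, this would write
# dense_out.ply and dense_out.pt:
#
#   python reconstruct.py input_cloud.ply 0.01 1.0 10 --devices cuda:0 cuda:1 --output dense_out
#
# A minimal sketch of loading the saved metadata afterwards, assuming torch, nn, and the same MLP class used
# above are importable and the saved charts have the same architecture:
#
#   data = torch.load("dense_out.pt", map_location="cpu")
#   charts = nn.ModuleList([MLP(2, 3) for _ in data["patch_uvs"]])
#   charts.load_state_dict(data["final_model"])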
def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument("mesh_filename", type=str, help="Point cloud to reconstruct")
    argparser.add_argument("--plot", action="store_true", help="Plot the output when done training")
    argparser.add_argument("--local-epochs", "-nl", type=int, default=128,
                           help="Number of local fitting iterations")
    argparser.add_argument("--global-epochs", "-ng", type=int, default=128,
                           help="Number of global fitting iterations")
    argparser.add_argument("--learning-rate", "-lr", type=float, default=1e-3,
                           help="Step size for gradient descent")
    argparser.add_argument("--device", "-d", type=str, default="cuda",
                           help="The device to use when fitting (either 'cpu' or 'cuda')")
    argparser.add_argument("--exact-emd", "-e", action="store_true",
                           help="Use exact optimal transport distance instead of Sinkhorn")
    argparser.add_argument("--max-sinkhorn-iters", "-si", type=int, default=32,
                           help="Maximum number of Sinkhorn iterations")
    argparser.add_argument("--sinkhorn-epsilon", "-sl", type=float, default=1e-3,
                           help="The reciprocal (1/lambda) of the Sinkhorn regularization parameter.")
    argparser.add_argument("--output", "-o", type=str, default="out.pt",
                           help="Destination to save the output reconstruction. Note, the file produced by this script "
                                "is not a mesh or a point cloud. To construct a dense point cloud, "
                                "see export_point_cloud.py.")
    argparser.add_argument("--seed", "-s", type=int, default=-1,
                           help="Random seed to use when initializing network weights. "
                                "If the seed is not positive, a seed is selected at random.")
    argparser.add_argument("--use-best", action="store_true",
                           help="Use the model with the lowest loss")
    argparser.add_argument("--print-every", type=int, default=16,
                           help="Print every N epochs")
    args = argparser.parse_args()

    # We'll populate this dictionary and save it as output
    output_dict = {
        "final_model": None,
        "uv": None,
        "x": None,
        "transform": None,
        "exact_emd": args.exact_emd,
        "global_epochs": args.global_epochs,
        "local_epochs": args.local_epochs,
        "learning_rate": args.learning_rate,
        "device": args.device,
        "sinkhorn_epsilon": args.sinkhorn_epsilon,
        "max_sinkhorn_iters": args.max_sinkhorn_iters,
        "seed": utils.seed_everything(args.seed),
    }

    # Read a point cloud and normals from a file
    x, n = utils.load_point_cloud_by_file_extension(args.mesh_filename, compute_normals=True)

    # Center the point cloud about its mean and align it along its principal components
    x, transform = transform_pointcloud(x, args.device)

    # Generate an initial set of UV samples in the plane
    uv = torch.tensor(pcu.lloyd_2d(x.shape[0]).astype(np.float32), requires_grad=True, device=args.device)

    # Initialize the model for the surface
    # phi = mlp_ultra_shallow(2, 3, hidden=8192).to(args.device)
    phi = MLP(2, 3).to(args.device)
    # phi = MLPWideAndDeep(2, 3).to(args.device)

    output_dict["uv"] = uv
    output_dict["x"] = x
    output_dict["transform"] = transform

    optimizer = torch.optim.Adam(phi.parameters(), lr=args.learning_rate)
    uv_optimizer = torch.optim.Adam([uv], lr=args.learning_rate)
    sinkhorn_loss = SinkhornLoss(max_iters=args.max_sinkhorn_iters, return_transport_matrix=True)
    mse_loss = nn.MSELoss()

    # Cache correspondences to plot them later
    pi = None

    # Cache the model with the lowest loss if --use-best is passed
    best_model = None
    best_loss = np.inf

    for epoch in range(args.local_epochs):
        optimizer.zero_grad()
        uv_optimizer.zero_grad()

        epoch_start_time = time.time()
        y = phi(uv)

        with torch.no_grad():
            if args.exact_emd:
                M = pairwise_distances(x.unsqueeze(0), y.unsqueeze(0)).squeeze().cpu().numpy()
                p = ot.emd(np.ones(x.shape[0]), np.ones(y.shape[0]), M)
                p = torch.from_numpy(p.astype(np.float32)).to(args.device)
            else:
                _, p = sinkhorn_loss(x.unsqueeze(0), y.unsqueeze(0))
            pi = p.squeeze().max(0)[1]

        loss = mse_loss(x[pi].unsqueeze(0), y.unsqueeze(0))
        loss.backward()

        if args.use_best and loss.item() < best_loss:
            best_loss = loss.item()
            best_model = copy.deepcopy(phi.state_dict())

        epoch_end_time = time.time()

        if epoch % args.print_every == 0:
            print("%d/%d: [Loss = %0.5f] [Time = %0.3f]" %
                  (epoch, args.local_epochs, loss.item(), epoch_end_time - epoch_start_time))
        optimizer.step()
        uv_optimizer.step()

    if args.use_best:
        phi.load_state_dict(best_model)

    output_dict["final_model"] = copy.deepcopy(phi.state_dict())
    torch.save(output_dict, args.output)

    if args.plot:
        plot_reconstruction(uv, x, transform, phi, pad=1.0)
        plot_correspondences(phi, uv, x, pi)
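# Example invocation for the single-patch script above (the script name is hypothetical; the flags come from
# the parser above). This variant fits a single chart to the entire point cloud:
#
#   python reconstruct_single_patch.py input_cloud.ply --device cuda --local-epochs 256 -o single_out.pt
#
# The resulting single_out.pt stores the fitted MLP weights together with the uv samples and normalizing
# transform; per the --output help text above, a dense point cloud is then produced by export_point_cloud.py.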
def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument("mesh_filename", type=str, help="Point cloud to reconstruct")
    argparser.add_argument("radius", type=float,
                           help="Patch radius (The parameter, r, in the paper)")
    argparser.add_argument("padding", type=float,
                           help="Padding factor for patches (The parameter, c, in the paper)")
    argparser.add_argument("--plot", action="store_true",
                           help="Plot the output when done training")
    argparser.add_argument("--angle-threshold", "-a", type=float, default=95.0,
                           help="Threshold (in degrees) used to discard points in "
                                "a patch whose normal is facing the wrong way.")
    argparser.add_argument("--local-epochs", "-nl", type=int, default=512,
                           help="Number of local fitting iterations")
    argparser.add_argument("--global-epochs", "-ng", type=int, default=1024,
                           help="Number of global fitting iterations")
    argparser.add_argument("--learning-rate", "-lr", type=float, default=1e-3,
                           help="Step size for gradient descent")
    argparser.add_argument("--device", "-d", type=str, default="cuda",
                           help="The device to use when fitting (either 'cpu' or 'cuda')")
    argparser.add_argument("--interpolate", action="store_true",
                           help="If set, then force all patches to agree with the input at overlapping points. "
                                "Otherwise, we fit all patches to the average of overlapping patches at each point.")
    argparser.add_argument("--max-sinkhorn-iters", "-si", type=int, default=32,
                           help="Maximum number of Sinkhorn iterations")
    argparser.add_argument("--sinkhorn-epsilon", "-sl", type=float, default=1e-3,
                           help="The reciprocal (1/lambda) of the Sinkhorn regularization parameter.")
    argparser.add_argument("--output", "-o", type=str, default="out.pt",
                           help="Destination to save the output reconstruction. Note, the file produced by this script "
                                "is not a mesh or a point cloud. To construct a dense point cloud, see upsample.py.")
    argparser.add_argument("--seed", "-s", type=int, default=-1,
                           help="Random seed to use when initializing network weights. "
                                "If the seed is not positive, a seed is selected at random.")
    argparser.add_argument("--exact-emd", "-e", action="store_true",
                           help="Use exact optimal transport distance instead of Sinkhorn")
    argparser.add_argument("--use-best", action="store_true",
                           help="Use the model with the lowest loss")
    args = argparser.parse_args()

    # We'll populate this dictionary and save it as output
    output_dict = {
        "pre_cycle_consistency_model": None,
        "final_model": None,
        "patch_uvs": None,
        "patch_idx": None,
        "patch_txs": None,
        "interpolate": args.interpolate,
        "global_epochs": args.global_epochs,
        "local_epochs": args.local_epochs,
        "learning_rate": args.learning_rate,
        "device": args.device,
        "sinkhorn_epsilon": args.sinkhorn_epsilon,
        "max_sinkhorn_iters": args.max_sinkhorn_iters,
        "seed": utils.seed_everything(args.seed),
    }

    # Read a point cloud and normals from a file, center it about its mean, and align it along its principal axes
    x, n = utils.load_point_cloud_by_file_extension(args.mesh_filename, compute_normals=True)

    # Compute a set of neighborhoods (patches) and uv samples for each neighborhood. Store the result in a list
    # of pairs (uv_j, xi_j) where uv_j are 2D uv coordinates for the j^th patch, and xi_j are the indices into x of
    # the j^th patch. We will try to reconstruct a function phi, such that phi(uv_j) = x[xi_j].
    bbox_diag = np.linalg.norm(np.max(x, axis=0) - np.min(x, axis=0))
    patch_idx, patch_uvs, patch_xs, patch_tx = compute_patches(x, n, args.radius * bbox_diag, args.padding,
                                                               args.angle_threshold, args.device)
    num_patches = len(patch_uvs)
    output_dict["patch_uvs"] = patch_uvs
    output_dict["patch_idx"] = patch_idx
    output_dict["patch_txs"] = patch_tx

    if args.plot:
        plot_patches(x, patch_idx)

    # Initialize one model per patch and convert the input data to a pytorch tensor
    phi = nn.ModuleList([MLP(2, 3).to(args.device) for _ in range(num_patches)])
    x = torch.from_numpy(x.astype(np.float32)).to(args.device)

    optimizer = torch.optim.Adam(phi.parameters(), lr=args.learning_rate)
    uv_optimizer = torch.optim.Adam(patch_uvs, lr=args.learning_rate)
    sinkhorn_loss = SinkhornLoss(max_iters=args.max_sinkhorn_iters, return_transport_matrix=True)
    mse_loss = nn.MSELoss()

    # Fit a function, phi_i, for each patch so that phi_i(patch_uvs[i]) = x[patch_idx[i]]. i.e. so that the function
    # phi_i "agrees" with the point cloud on each patch.
    #
    # We also store the correspondences between the uvs and points which we use later for the consistency step. The
    # correspondences are stored in a list, pi, where pi[i] is a vector of integers used to permute the points in
    # a patch.
    pi = [None for _ in range(num_patches)]

    # Cache the models with the lowest loss if --use-best is passed
    best_models = [None for _ in range(num_patches)]
    best_losses = [np.inf for _ in range(num_patches)]

    for epoch in range(args.local_epochs):
        optimizer.zero_grad()
        uv_optimizer.zero_grad()

        sum_loss = torch.tensor([0.0]).to(args.device)
        epoch_start_time = time.time()
        for i in range(num_patches):
            uv_i = patch_uvs[i]
            x_i = patch_xs[i]
            y_i = phi[i](uv_i)

            with torch.no_grad():
                if args.exact_emd:
                    M_i = pairwise_distances(x_i.unsqueeze(0), y_i.unsqueeze(0)).squeeze().cpu().numpy()
                    p_i = ot.emd(np.ones(x_i.shape[0]), np.ones(y_i.shape[0]), M_i)
                    p_i = torch.from_numpy(p_i.astype(np.float32)).to(args.device)
                else:
                    _, p_i = sinkhorn_loss(x_i.unsqueeze(0), y_i.unsqueeze(0))
                pi_i = p_i.squeeze().max(0)[1]
                pi[i] = pi_i

            loss_i = mse_loss(x_i[pi_i].unsqueeze(0), y_i.unsqueeze(0))

            if args.use_best and loss_i.item() < best_losses[i]:
                best_losses[i] = loss_i.item()
                best_models[i] = copy.deepcopy(phi[i].state_dict())

            sum_loss += loss_i

        sum_loss.backward()
        epoch_end_time = time.time()

        print("%d/%d: [Total = %0.5f] [Mean = %0.5f] [Time = %0.3f]" %
              (epoch, args.local_epochs, sum_loss.item(), sum_loss.item() / num_patches,
               epoch_end_time - epoch_start_time))
        optimizer.step()
        uv_optimizer.step()

    if args.use_best:
        for i, phi_i in enumerate(phi):
            phi_i.load_state_dict(best_models[i])

    output_dict["pre_cycle_consistency_model"] = copy.deepcopy(phi.state_dict())

    if args.plot:
        plot_reconstruction(x, patch_uvs, patch_tx, phi, scale=0.8)

    # Do a second, global, stage of fitting where we ask all patches to agree with each other on overlapping points.
    # If the user passed --interpolate, we ask that the patches agree on the original input points, otherwise we ask
    # that they agree on the average of predictions from patches overlapping a given point.
    for epoch in range(args.global_epochs):
        optimizer.zero_grad()
        uv_optimizer.zero_grad()

        sum_loss = torch.tensor([0.0]).to(args.device)
        epoch_start_time = time.time()
        for i in range(num_patches):
            uv_i = patch_uvs[i]
            x_i = patch_xs[i]
            y_i = phi[i](uv_i)
            pi_i = pi[i]

            loss_i = mse_loss(x_i[pi_i].unsqueeze(0), y_i.unsqueeze(0))

            if loss_i.item() < best_losses[i]:
                best_losses[i] = loss_i.item()
                best_models[i] = copy.deepcopy(phi[i].state_dict())

            sum_loss += loss_i

        sum_loss.backward()
        epoch_end_time = time.time()

        print("%d/%d: [Total = %0.5f] [Mean = %0.5f] [Time = %0.3f]" %
              (epoch, args.global_epochs, sum_loss.item(), sum_loss.item() / num_patches,
               epoch_end_time - epoch_start_time))
        optimizer.step()
        uv_optimizer.step()

    for i, phi_i in enumerate(phi):
        phi_i.load_state_dict(best_models[i])

    output_dict["final_model"] = copy.deepcopy(phi.state_dict())
    torch.save(output_dict, args.output)

    if args.plot:
        plot_reconstruction(x, patch_uvs, patch_tx, phi, scale=0.8)
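# Example invocation for the single-device script above (the script name is hypothetical; the positional
# arguments and flags come from the parser above). The radius and padding are passed positionally, every patch
# stays on one device, and only the .pt metadata is written (see upsample.py, referenced in the --output help
# text, for producing a dense point cloud):
#
#   python reconstruct_single_device.py input_cloud.ply 0.01 1.0 --device cuda -o patches_out.pt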