Example #1
def log_outputs(opt, step, all_outputs, all_inputs):
    # Source
    # use the y coordinate of the source shape as the scalar colour property
    color = all_inputs["source_shape"][:, :, 1].detach().cpu().numpy()
    save_ply_property(os.path.join(opt.log_dir,"step-{:06d}-Sa.ply".format(step)), all_inputs["source_shape"][0].detach().cpu().numpy(), color[0], cmap_name="rainbow")
    # Target
    save_ply_property(os.path.join(opt.log_dir,"step-{:06d}-Sb.ply".format(step)), all_inputs["target_shape"][0].detach().cpu().numpy(), color[0], cmap_name="rainbow")
    for batch in range(0, all_outputs["cage"].shape[0], opt.batch_size):
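        # the batch axis of all_outputs packs four deformation directions in
        # consecutive chunks of opt.batch_size: source-to-target, target-to-source,
        # source-to-source and target-to-target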
        if batch // opt.batch_size == 0:
            tag = "StoT"
        elif batch // opt.batch_size == 1:
            tag = "TtoS"
        elif batch // opt.batch_size == 2:
            tag = "StoS"
        elif batch // opt.batch_size == 3:
            tag = "TtoT"

        # deformed and cage
        save_ply_property(os.path.join(opt.log_dir,"step-{:06d}-{}-Sab.ply".format(step, tag)),
                    all_outputs["deformed"][batch].detach().cpu().numpy(), color[batch], cmap_name="rainbow")
        write_trimesh(os.path.join(opt.log_dir, "step-{:06d}-{}-cage1.ply".format(step, tag)),
                    all_outputs["cage"][batch].detach().cpu(), all_outputs["cage_face"][0].detach().cpu(), binary=True)
        write_trimesh(os.path.join(opt.log_dir, "step-{:06d}-{}-cage2.ply".format(step, tag)),
                    all_outputs["new_cage"][batch].detach().cpu(), all_outputs["cage_face"][0].detach().cpu(), binary=True)

        # if using network2
        if "cage_surface" in all_outputs:
            save_ply(os.path.join(opt.log_dir,"step-{:06d}-{}-cage_surface1.ply".format(step, tag)), all_outputs["cage_surface"][batch].detach().cpu().numpy())
            save_ply(os.path.join(opt.log_dir,"step-{:06d}-{}-cage_surface2.ply".format(step, tag)), all_outputs["new_cage_surface"][batch].detach().cpu().numpy())
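The save_ply_property helper used above is project-specific; presumably it maps a per-vertex scalar through a colormap and writes a colored point cloud. A self-contained sketch of that idea (the helper name and the exact PLY layout below are illustrative assumptions, not the project's implementation):

import numpy as np
import matplotlib.pyplot as plt

def save_colored_ply_sketch(path, points, prop, cmap_name="rainbow"):
    # points: (N, 3) float array, prop: (N,) scalar per vertex.
    # Normalize the property to [0, 1], map it through a colormap and write
    # an ASCII PLY with per-vertex RGB colors.
    prop = (prop - prop.min()) / max(prop.max() - prop.min(), 1e-12)
    colors = (plt.get_cmap(cmap_name)(prop)[:, :3] * 255).astype(np.uint8)
    with open(path, "w") as f:
        f.write("ply\nformat ascii 1.0\n")
        f.write("element vertex {}\n".format(points.shape[0]))
        f.write("property float x\nproperty float y\nproperty float z\n")
        f.write("property uchar red\nproperty uchar green\nproperty uchar blue\n")
        f.write("end_header\n")
        for p, c in zip(points, colors):
            f.write("{} {} {} {} {} {}\n".format(p[0], p[1], p[2], c[0], c[1], c[2]))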
Example #2
def writeCameras(scene, savePath):
    position = torch.cat([c.position for c in scene.cameras], dim=0)
    normal = torch.cat([c.rotation[:, :, 2] for c in scene.cameras], dim=0)
    save_ply(position.cpu().numpy(), savePath, normals=normal.cpu().numpy())
    # SOURCE_PATH = os.path.join(DATA_DIR, "{:d}".format(DIGIT), SOURCE_NAME+".ply")
    SOURCE_PATH = "/home/yifan/Documents/Cage/scripts/wlop/build/gingerbreadman.ply"
    CAGE_PATH = "/home/yifan/Documents/Cage/scripts/wlop/build/gingerbreadman_cage.ply"

    # polygon_list = [
    #     (-0.523185483870968,	0.553246753246753),
    #     (-0.644153225806452,	-0.101298701298701),
    #     (-0.166330645161290,	-0.218181818181818),
    #     (0.190524193548387,	-0.381818181818182),
    #     (0.450604838709678,	-0.553246753246754),
    #     (0.656250000000000,	0),
    #     (0.335685483870968,	0.225974025974026),
    #     (-0.154233870967742,	0.444155844155844),
    # ]

    # polygon = torch.tensor([(x, y) for x, y in polygon_list], dtype=torch.float).unsqueeze(0).transpose(1, 2)

    source = torch.tensor(load(SOURCE_PATH)[:,:2], dtype=torch.float).unsqueeze(0).transpose(1,2)
    save_ply(source[0].transpose(0,1).numpy(), "../vanilla_data/{}/{}.ply".format(SOURCE_NAME, SOURCE_NAME))
    polygon = torch.tensor(load(CAGE_PATH))[:,:2].unsqueeze(0).transpose(1,2)
    save_ply(polygon[0].transpose(0,1).numpy(), "../vanilla_data/{}/{}-cage.ply".format(SOURCE_NAME, SOURCE_NAME), binary=False)
    weights = mean_value_coordinates(source, polygon)
    # perturb
    for i in range(PERTURB_EPOCH):
        new_polygon = polygon
        for k in range(PERTURB_ITER):
            new_polygon = perturb(new_polygon, RADIUS_PERTURB, ANGLE_PERTURB)
            # (B,2,M,N) * (B,2,M,1) -> (B,2,N)
            deformed = torch.sum(weights.unsqueeze(1)*new_polygon.unsqueeze(-1), dim=2)
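            # The weighted sum above is plain mean-value-coordinate interpolation:
            # every deformed point is a weighted combination of the cage vertices.
            # With weights of shape (B, M, N) and a cage of shape (B, 2, M), an
            # equivalent (sketch) formulation is:
            #     deformed = torch.einsum("bmn,bcm->bcn", weights, new_polygon)  # (B, 2, N)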
            save_ply(deformed[0].transpose(0,1).numpy(), "../vanilla_data/{}/{}-{}.ply".format(SOURCE_NAME, SOURCE_NAME, i*PERTURB_ITER+k))
            save_ply(new_polygon[0].transpose(0,1).numpy(), "../vanilla_data/{}/{}-{}-cage.ply".format(SOURCE_NAME, SOURCE_NAME, i*PERTURB_ITER+k), binary=False)
    assert(len(points_paths) > 0), "Found no point clouds with path {}".format(points_paths)
    points_relpaths = None
    if len(points_paths) > 1:
        points_dir = os.path.commonpath(points_paths)
        points_relpaths = [os.path.relpath(p, points_dir) for p in points_paths]
    else:
        points_relpaths = [os.path.basename(p) for p in points_paths]

    torch.manual_seed(24)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(24)

    view_sampler = SphericalSampler(300, 'SPIRAL')
    points = view_sampler.points
    save_ply(points,'example_data/pointclouds/spiral_300.ply',normals=points)


    scene = readScene(opt.source, device="cpu")
    opt.genCamera = 300
    if opt.genCamera > 0:
        camSampler = CameraSampler(opt.genCamera, opt.camOffset, opt.camFocalLength, points=scene.cloud.localPoints,
                                   camWidth=opt.width, camHeight=opt.height, filename="example_data/pointclouds/spiral_300.ply")
        camSampler.closer = False
    with torch.no_grad():
        splatter = createSplatter(opt, scene=scene)
        #splatter.shading = 'depth'
        if opt.genCamera > 0:
            cameras = []
            for i in range(opt.genCamera):
                cam = next(camSampler)
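SphericalSampler(300, 'SPIRAL') above is also project-specific; presumably it spreads view positions over a sphere along a spiral. A self-contained sketch of one common construction (a Fibonacci spiral; the class internals are an assumption):

import numpy as np

def spiral_sphere_points_sketch(n, radius=1.0):
    # Place n points on a sphere along a Fibonacci spiral: evenly spaced in z,
    # rotated by the golden angle in azimuth.
    i = np.arange(n)
    z = 1.0 - 2.0 * (i + 0.5) / n                  # uniform in [-1, 1]
    r = np.sqrt(np.maximum(0.0, 1.0 - z * z))      # radius of each z-slice
    theta = np.pi * (1.0 + np.sqrt(5.0)) * i       # golden-angle increments
    return radius * np.stack([r * np.cos(theta), r * np.sin(theta), z], axis=1)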
Example #5
def test(net=None, subdir="test"):
    opt.phase = "test"
    if isinstance(opt.target_model, str):
        opt.target_model = [opt.target_model]

    if net is None:
        states = torch.load(opt.ckpt)
        if "states" in states:
            states = states["states"]
        if opt.template:
            cage_shape, cage_face = read_trimesh(opt.template)
            cage_shape = torch.from_numpy(
                cage_shape[:, :3]).unsqueeze(0).float()
            cage_face = torch.from_numpy(cage_face).unsqueeze(0).long()
            states["template_vertices"] = cage_shape.transpose(1, 2)
            states["template_faces"] = cage_face

        if opt.source_model:
            source_shape, source_face = read_trimesh(opt.source_model)
            source_shape = torch.from_numpy(
                source_shape[:, :3]).unsqueeze(0).float()
            source_face = torch.from_numpy(source_face).unsqueeze(0).long()
            states["source_vertices"] = source_shape.transpose(1, 2)
            states["source_faces"] = source_shape

        net = networks.FixedSourceDeformer(
            opt,
            3,
            opt.num_point,
            bottleneck_size=opt.bottleneck_size,
            template_vertices=states["template_vertices"],
            template_faces=states["template_faces"],
            source_vertices=states["source_vertices"],
            source_faces=states["source_faces"]).cuda()

        load_network(net, states)
        net = net.cuda()
        net.eval()
    else:
        net.eval()

    print(net)

    test_output_dir = os.path.join(opt.log_dir, subdir)
    os.makedirs(test_output_dir, exist_ok=True)
    with torch.no_grad():
        for target_model in opt.target_model:
            assert os.path.isfile(target_model), "Cannot find target model {}".format(target_model)
            target_face = None
            target_shape, target_face = read_trimesh(target_model)
            # target_shape = read_ply(target_model)[:,:3]
            # target_shape, _, scale = normalize_to_box(target_shape)
            # normalize according to height (y axis)
            # target_shape = target_shape/2*1.7
            target_shape = torch.from_numpy(
                target_shape[:, :3]).cuda().float().unsqueeze(0)
            if target_face is None:
                target_face = net.source_faces
            else:
                target_face = torch.from_numpy(
                    target_face).cuda().long().unsqueeze(0)
            t_filename = os.path.splitext(os.path.basename(target_model))[0]

            source_mesh = net.source_vertices.transpose(1, 2).detach()
            source_face = net.source_faces.detach()

            # furthest sampling
            target_shape_sampled = furthest_point_sample(
                target_shape, net.source_vertices.shape[2], NCHW=False)[1]
            # target_shape_sampled = (target_shape[:, np.random.permutation(target_shape.shape[1]), :]).contiguous()
            outputs = net(target_shape_sampled.transpose(1, 2),
                          None,
                          cage_only=True)
            # deformed = outputs["deformed"]
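            # deform_with_MVC (sketching its presumed behaviour): compute the
            # mean-value coordinates of source_mesh w.r.t. the predicted cage,
            # then re-evaluate them against new_cage to obtain the deformed mesh.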

            deformed = deform_with_MVC(
                outputs["cage"], outputs["new_cage"],
                outputs["cage_face"].expand(outputs["cage"].shape[0], -1,
                                            -1), source_mesh)

            b = 0

            save_ply(
                target_shape_sampled[b].cpu().numpy(),
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sb.pts".format(t_filename)))
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sa.ply".format(t_filename)),
                source_mesh[0].detach().cpu(), source_face[0].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sb.ply".format(t_filename)),
                target_shape[b].detach().cpu(), target_face[b].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sab.ply".format(t_filename)),
                deformed[b].detach().cpu(), source_face[b].detach().cpu())

            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-cage1.ply".format(t_filename)),
                outputs["cage"][b].detach().cpu(),
                outputs["cage_face"][b].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-cage2.ply".format(t_filename)),
                outputs["new_cage"][b].detach().cpu(),
                outputs["cage_face"][b].detach().cpu())

    PairedSurreal.render_result(test_output_dir)
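furthest_point_sample above comes from the project's point-cloud utilities (its exact signature, including the NCHW flag, is the project's own). A minimal pure-PyTorch sketch of the same idea for a single unbatched cloud:

import torch

def farthest_point_sample_sketch(points, k):
    # points: (N, 3) tensor; returns indices of k points chosen by iterative
    # farthest-point selection (brute force O(N*k), for illustration only).
    n = points.shape[0]
    chosen = torch.zeros(k, dtype=torch.long)
    dist = torch.full((n,), float("inf"))
    farthest = 0
    for i in range(k):
        chosen[i] = farthest
        d = torch.sum((points - points[farthest]) ** 2, dim=1)
        dist = torch.minimum(dist, d)
        farthest = int(torch.argmax(dist))
    return chosen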
Example #6
def test_all(net=None, subdir="test"):
    opt.phase = "test"
    dataset = build_dataset(opt)

    if net is None:
        source_shape = dataset.mesh_vertex.unsqueeze(0).to(dtype=torch.float)
        source_face = dataset.mesh_face.unsqueeze(0)
        cage_shape = dataset.cage_vertex.unsqueeze(0).to(dtype=torch.float)
        cage_face = dataset.cage_face.unsqueeze(0)
        net = networks.FixedSourceDeformer(
            opt,
            3,
            opt.num_point,
            bottleneck_size=opt.bottleneck_size,
            template_vertices=cage_shape.transpose(1, 2),
            template_faces=cage_face,
            source_vertices=source_shape.transpose(1, 2),
            source_faces=source_face).cuda()

        load_network(net, opt.ckpt)
        net.eval()
    else:
        net.eval()

    print(net)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=3,
        worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] +
                                                 id))

    chamfer_distance = losses.LabeledChamferDistance(beta=0, gamma=1)
    mse_distance = torch.nn.MSELoss()
    avg_CD = 0
    avg_EMD = 0
    test_output_dir = os.path.join(opt.log_dir, subdir)
    os.makedirs(test_output_dir, exist_ok=True)
    with open(os.path.join(test_output_dir, "eval.txt"), "w") as f:
        with torch.no_grad():
            source_mesh = net.source_vertices.transpose(1, 2).detach()
            source_face = net.source_faces.detach()
            for i, data in enumerate(dataloader):
                data = dataset.uncollate(data)

                target_shape, target_filename = data["target_shape"], data[
                    "target_file"]

                sample_idx = None
                if "sample_idx" in data:
                    sample_idx = data["sample_idx"]
                outputs = net(target_shape.transpose(1, 2), sample_idx)
                deformed = outputs["deformed"]

                deformed = deform_with_MVC(
                    outputs["cage"], outputs["new_cage"],
                    outputs["cage_face"].expand(outputs["cage"].shape[0], -1,
                                                -1), source_mesh)

                for b in range(outputs["deformed"].shape[0]):
                    t_filename = os.path.splitext(target_filename[b])[0]
                    target_shape_np = target_shape.detach().cpu()[b].numpy()
                    if data["target_face"] is not None and data[
                            "target_mesh"] is not None:
                        pymesh.save_mesh_raw(
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sa.ply".format(t_filename)),
                            source_mesh[0].detach().cpu(),
                            source_face[0].detach().cpu())
                        pymesh.save_mesh_raw(
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sb.ply".format(t_filename)),
                            data["target_mesh"][b].detach().cpu(),
                            data["target_face"][b].detach().cpu())
                        pymesh.save_mesh_raw(
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sab.ply".format(t_filename)),
                            deformed[b].detach().cpu(),
                            source_face[b].detach().cpu())
                    else:
                        save_ply(
                            source_mesh[0].detach().cpu(),
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sa.ply".format(t_filename)))
                        save_ply(
                            target_shape[b].detach().cpu(),
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sb.ply".format(t_filename)),
                            normals=data["target_normals"][b].detach().cpu())
                        save_ply(
                            deformed[b].detach().cpu(),
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sab.ply".format(t_filename)),
                            normals=data["target_normals"][b].detach().cpu())

                    pymesh.save_mesh_raw(
                        os.path.join(
                            opt.log_dir, subdir,
                            "template-{}-cage1.ply".format(t_filename)),
                        outputs["cage"][b].detach().cpu(),
                        outputs["cage_face"][b].detach().cpu())
                    pymesh.save_mesh_raw(
                        os.path.join(
                            opt.log_dir, subdir,
                            "template-{}-cage2.ply".format(t_filename)),
                        outputs["new_cage"][b].detach().cpu(),
                        outputs["cage_face"][b].detach().cpu())

                    log_str = "{}/{} {}".format(i, len(dataloader), t_filename)
                    print(log_str)
                    f.write(log_str + "\n")

    dataset.render_result(test_output_dir)
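LabeledChamferDistance is the project's evaluation loss (the avg_CD / avg_EMD accumulation is not shown in this excerpt). For reference, a minimal sketch of a plain symmetric chamfer distance:

import torch

def chamfer_distance_sketch(a, b):
    # a: (B, N, 3), b: (B, M, 3); mean nearest-neighbour squared distance in
    # both directions (brute-force pairwise distances, small clouds only).
    d2 = torch.cdist(a, b) ** 2              # (B, N, M)
    return d2.min(dim=2).values.mean() + d2.min(dim=1).values.mean()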
Example #7
def optimize(opt):
    """
    weights are the same with the original source mesh
    target=net(old_source)
    """
    # load new target
    if opt.is_poly:
        target_mesh = om.read_polymesh(opt.model)
    else:
        target_mesh = om.read_trimesh(opt.model)
    target_shape_arr = target_mesh.points()
    target_shape = target_shape_arr.copy()
    target_shape = torch.from_numpy(
        target_shape[:, :3].astype(np.float32)).cuda()
    target_shape.unsqueeze_(0)

    states = torch.load(opt.ckpt)
    if "states" in states:
        states = states["states"]
    cage_v = states["template_vertices"].transpose(1, 2).cuda()
    cage_f = states["template_faces"].cuda()
    shape_v = states["source_vertices"].transpose(1, 2).cuda()
    shape_f = states["source_faces"].cuda()

    new_label_path = opt.model.replace(os.path.splitext(opt.model)[1], ".picked")
    orig_label_path = opt.source_model.replace(os.path.splitext(opt.source_model)[1], ".picked")
    if os.path.isfile(new_label_path) and os.path.isfile(orig_label_path):
        logger.info("Loading picked labels {} and {}".format(orig_label_path, new_label_path))
        import pandas as pd
        new_label = pd.read_csv(new_label_path, delimiter=" ",skiprows=1, header=None)
        orig_label = pd.read_csv(orig_label_path, delimiter=" ",skiprows=1, header=None)
        orig_label_name = orig_label.iloc[:,5]
        new_label_name = new_label.iloc[:,5].tolist()
        new_to_orig_idx = []
        for i, name in enumerate(new_label_name):
            matched_idx = orig_label_name[orig_label_name==name].index
            if matched_idx.size == 1:
                new_to_orig_idx.append((i, matched_idx[0]))
        new_to_orig_idx = np.array(new_to_orig_idx)
        if new_label.shape[1] == 10:
            new_vidx = new_label.iloc[:,9].to_numpy()[new_to_orig_idx[:,0]]
            target_points = target_shape[:, new_vidx, :]
        else:
            new_label_points = torch.from_numpy(new_label.iloc[:,6:9].to_numpy().astype(np.float32))
            target_points = new_label_points.unsqueeze(0).cuda()
            target_points, new_vidx, _ = faiss_knn(1, target_points, target_shape, NCHW=False)
            target_points = target_points.squeeze(2) # B,N,3
            new_label[9] = new_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            new_label.to_csv(new_label_path, sep=" ", header=[str(new_label.shape[0])]+[""]*(new_label.shape[1]-1), index=False)
            target_points = target_points[:, new_to_orig_idx[:,0], :]

        target_points = target_points.cuda()
        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :,:3]).float()
        if orig_label.shape[1] == 10:
            orig_vidx = orig_label.iloc[:,9].to_numpy()[new_to_orig_idx[:,1]]
            source_points = source_shape[:, orig_vidx, :]
        else:
            orig_label_points = torch.from_numpy(orig_label.iloc[:,6:9].to_numpy().astype(np.float32))
            source_points = orig_label_points.unsqueeze(0)
            # find the closest point on the original meshes
            source_points, new_vidx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2) # B,N,3
            orig_label[9] = new_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            orig_label.to_csv(orig_label_path, sep=" ", header=[str(orig_label.shape[0])]+[""]*(orig_label.shape[1]-1), index=False)
            source_points = source_points[:,new_to_orig_idx[:,1],:]

        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
        source_points = source_points.cuda()
        # # shift target so that the belly match
        # try:
        #     orig_bellyUp_idx = orig_label_name[orig_label_name=="bellUp"].index[0]
        #     orig_bellyUp = orig_label_points[orig_bellyUp_idx, :]
        #     new_bellyUp_idx = [i for i, i2 in new_to_orig_idx if i2==orig_bellyUp_idx][0]
        #     new_bellyUp = new_label_points[new_bellyUp_idx,:]
        #     target_points += (orig_bellyUp - new_bellyUp)
        # except Exception as e:
        #     logger.warn("Couldn\'t match belly to belly")
        #     traceback.print_exc(file=sys.stdout)

        # source_points[0] = center_bounding_box(source_points[0])[0]
    elif not os.path.isfile(new_label_path) and os.path.isfile(orig_label_path):
        logger.info("Assuming Faust model")
        logger.info("Loading picked labels {}".format(orig_label_path))
        import pandas as pd
        orig_label = pd.read_csv(orig_label_path, delimiter=" ",skiprows=1, header=None)
        orig_label_name = orig_label.iloc[:,5]
        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :,:3]).cuda().float()
        if orig_label.shape[1] == 10:
            idx = torch.from_numpy(orig_label.iloc[:,9].to_numpy()).long()
            source_points = source_shape[:,idx,:]
            target_points = target_shape[:,idx,:]
        else:
            source_points = torch.from_numpy(orig_label.iloc[:,6:9].to_numpy().astype(np.float32))
            source_points = source_points.unsqueeze(0).cuda()
            # find the closest point on the original meshes
            source_points, idx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2) # B,N,3
            idx = idx.squeeze(-1)
            target_points = target_shape[:,idx,:]

        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
    elif opt.corres_idx is None and target_shape.shape[1] == shape_v.shape[1]:
        logger.info("No correspondence provided, assuming registered Faust models")
        # corresp_idx = torch.randint(0, shape_f.shape[1], (100,)).cuda()
        corresp_v = torch.unique(torch.randint(0, shape_v.shape[1], (4800,))).cuda()
        target_points = torch.index_select(target_shape, 1, corresp_v)
        source_points = torch.index_select(shape_v, 1, corresp_v)

    target_shape[0], target_center, target_scale = center_bounding_box(target_shape[0])
    _, _, source_scale = center_bounding_box(shape_v[0])
    target_scale_factor = (source_scale/target_scale)[1]
    target_shape *= target_scale_factor
    target_points -= target_center
    target_points = (target_points*target_scale_factor).detach()
    # make sure the test stage uses the normalized target mesh
    target_shape_arr[:] = target_shape[0].cpu().numpy()
    om.write_mesh(os.path.join(opt.log_dir, opt.subdir, os.path.splitext(
        os.path.basename(opt.model))[0]+"_normalized.obj"), target_mesh)
    opt.model = os.path.join(opt.log_dir, opt.subdir, os.path.splitext(
        os.path.basename(opt.model))[0]+"_normalized.obj")
    pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-initial.obj"),
                         shape_v[0].cpu().numpy(), shape_f[0].cpu().numpy())
    pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "cage-initial.obj"),
                         cage_v[0].cpu().numpy(), cage_f[0].cpu().numpy())
    save_ply(target_points[0].cpu().numpy(), os.path.join(
        opt.log_dir, opt.subdir, "target_points.ply"))
    save_ply(source_points[0].cpu().numpy(), os.path.join(
        opt.log_dir, opt.subdir, "source_points.ply"))
    logger.info("Optimizing for {} corresponding vertices".format(
        target_points.shape[1]))

    cage_init = cage_v.clone().detach()
    lap_loss = MeshLaplacianLoss(torch.nn.MSELoss(reduction="none"), use_cot=True,
                                 use_norm=True, consistent_topology=True, precompute_L=True)
    mvc_reg_loss = MVCRegularizer(threshold=50, beta=1.0, alpha=0.0)
    cage_v.requires_grad_(True)
    optimizer = torch.optim.Adam([cage_v], lr=opt.lr, betas=(0.5, 0.9))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(opt.nepochs*0.4), gamma=0.5, last_epoch=-1)

    if opt.dim == 3:
        weights_ref = mean_value_coordinates_3D(
            source_points, cage_init, cage_f, verbose=False)
    else:
        raise NotImplementedError
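    # Objective of the loop below: move cage_v so that the mean-value
    # coordinates of the target landmark points (w.r.t. the moving cage) match
    # weights_ref, the coordinates of the corresponding source points w.r.t.
    # the initial cage, optionally regularized by a cage Laplacian term and an
    # MVC magnitude penalty.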

    for t in range(opt.nepochs):
        optimizer.zero_grad()
        weights = mean_value_coordinates_3D(
            target_points, cage_v, cage_f, verbose=False)
        loss_mvc = torch.mean((weights-weights_ref)**2)
        # reg = torch.sum((cage_init-cage_v)**2, dim=-1)*1e-4
        reg = 0
        if opt.clap_weight > 0:
            reg = lap_loss(cage_init, cage_v, face=cage_f)*opt.clap_weight
            reg = reg.mean()
        if opt.mvc_weight > 0:
            reg += mvc_reg_loss(weights)*opt.mvc_weight

        # weight regularizer with the shape difference
        # dist = torch.sum((source_points - target_points)**2, dim=-1)
        # weights = torch.exp(-dist)
        # reg = reg*weights*0.1

        loss = loss_mvc + reg
        if (t+1) % 50 == 0:
            # float() works whether reg is the plain int 0 or a 0-dim tensor
            print("t {}/{} mvc_loss: {} reg: {}".format(
                t, opt.nepochs, loss_mvc.item(), float(reg)))

        if loss_mvc.item() < 5e-6:
            break
        loss.backward()
        optimizer.step()
        scheduler.step()

    return cage_v, cage_f
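center_bounding_box is another project helper; from its use above it appears to return the centered points together with the bounding-box center and per-axis half-extent. A sketch under that assumption:

import torch

def center_bounding_box_sketch(points):
    # points: (N, 3). Shift so the axis-aligned bounding box is centred at the
    # origin; also return the centre and the per-axis half-extent (so that
    # (source_scale / target_scale)[1] above compares heights along y).
    pmin = points.min(dim=0).values
    pmax = points.max(dim=0).values
    center = (pmin + pmax) / 2
    scale = (pmax - pmin) / 2
    return points - center, center, scale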