예제 #1
0
def test_one(opt, cage_shape, new_source, new_source_face, new_target, new_target_face):
    """Deform ``new_source`` toward ``new_target`` with a restored checkpoint.

    Loads ``opt.ckpt``, builds a ``FixedSourceDeformer`` around the given cage
    and source, runs it on ``new_target`` and writes PLY files to
    ``opt.log_dir/opt.subdir``:

    * ``template-initial.ply`` - the source stored in the checkpoint,
    * ``template-Sa.ply`` / ``template-Sb.ply`` - the new source / target,
    * ``template-Sab.ply`` - the network's ``"deformed"`` output.

    Only batch element 0 of each tensor is written.

    Args:
        opt: options; reads ``ckpt``, ``log_dir``, ``subdir`` and ``num_point``.
        cage_shape: cage vertices; dims 1/2 are transposed before being passed
            to the network as ``template_vertices``.
        new_source: source vertices, transposed the same way for the network.
        new_source_face: source faces.
        new_target: target vertices, fed (transposed) to the network.
        new_target_face: target faces.
    """
    states = torch.load(opt.ckpt)
    # Some checkpoints nest the tensors under a "states" key.
    if "states" in states:
        states = states["states"]

    out_dir = os.path.join(opt.log_dir, opt.subdir)

    pymesh.save_mesh_raw(os.path.join(out_dir, "template-initial.ply"),
                         states["source_vertices"][0].transpose(0, 1).detach().cpu(),
                         states["source_faces"][0].detach().cpu())

    pymesh.save_mesh_raw(os.path.join(out_dir, "template-Sa.ply"),
                         new_source[0].detach().cpu(), new_source_face[0].detach().cpu())
    pymesh.save_mesh_raw(os.path.join(out_dir, "template-Sb.ply"),
                         new_target[0].detach().cpu(), new_target_face[0].detach().cpu())

    # The caller's cage vertices are combined with the checkpoint's cage faces.
    net = networks.FixedSourceDeformer(opt, 3, opt.num_point, bottleneck_size=512,
                                       template_vertices=cage_shape.transpose(1, 2),
                                       template_faces=states["template_faces"].cuda(),
                                       source_vertices=new_source.transpose(1, 2),
                                       source_faces=new_source_face).cuda()

    net.eval()
    load_network(net, states)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = net(new_target.transpose(1, 2).contiguous())
    deformed = outputs["deformed"]

    # The network was built with new_source as its fixed source, so its
    # "deformed" output is presumably the deformed source. As in the other test
    # routines, it is written with the source faces, not the target's.
    pymesh.save_mesh_raw(os.path.join(out_dir, "template-Sab.ply"),
                         deformed[0].detach().cpu(), new_source_face[0].detach().cpu())
예제 #2
0
def train():
    """Train a ``FixedSourceDeformer`` on the dataset configured by ``opt``.

    Reads the module-level ``opt``. Builds the dataset and the cage mesh,
    constructs the network (restoring ``opt.ckpt`` when set), copies the
    training scripts into ``opt.log_dir`` and optimizes the encoder and
    ``nd_decoder`` with Adam.

    Once ``opt.warmup_epochs`` have elapsed, two things happen:

    * ``nc_decoder`` (``opt.deform_template``) and/or the template vertices
      (``opt.optimize_template``) join the optimizer at ``0.1 * opt.lr``.
    * The original ``opt.mvc_weight`` is restored; it is forced to 0 until
      then.

    Checkpoints are written every 500 steps ("latest"), every
    ``save_interval`` epochs and at the end. ``test_all`` is then run on the
    trained network.
    """
    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=0,
        worker_init_fn=lambda worker_id: np.random.seed(
            np.random.get_state()[1][0] + worker_id))
    source_shape = dataset.mesh_vertex.unsqueeze(0).to(dtype=torch.float)
    source_face = dataset.mesh_face.unsqueeze(0)
    cage_shape = dataset.cage_vertex.unsqueeze(0).to(dtype=torch.float)
    cage_face = dataset.cage_face.unsqueeze(0)
    mesh = Mesh(vertices=cage_shape[0], faces=cage_face[0])
    build_gemm(mesh, cage_face[0])
    cage_edge_points = torch.from_numpy(get_edge_points(mesh)).cuda()
    cage_edges = edge_vertex_indices(cage_face[0])

    # network
    net = networks.FixedSourceDeformer(
        opt,
        3,
        opt.num_point,
        bottleneck_size=opt.bottleneck_size,
        template_vertices=cage_shape.transpose(1, 2),
        template_faces=cage_face,
        source_vertices=source_shape.transpose(1, 2),
        source_faces=source_face).cuda()
    print(net)
    net.apply(weights_init)
    if opt.ckpt:
        load_network(net, opt.ckpt)
    net.train()

    all_losses = losses.AllLosses(opt)

    # Only the encoder and nd_decoder are optimized at first; further
    # parameter groups are added after warm-up (see the training loop).
    optimizer = torch.optim.Adam([
        {"params": net.nd_decoder.parameters()},
        {"params": net.encoder.parameters()},
    ], lr=opt.lr)

    # Snapshot the code that produced this run next to its logs.
    os.makedirs(opt.log_dir, exist_ok=True)
    script_dir = os.path.dirname(__file__)
    shutil.copy2(__file__, opt.log_dir)
    for script in ("network2.py", "common.py", "losses.py", "datasets.py"):
        shutil.copy2(os.path.join(script_dir, script), opt.log_dir)
    pymesh.save_mesh_raw(
        os.path.join(opt.log_dir, "t{:06d}_Sa.ply".format(0)),
        net.source_vertices[0].transpose(0, 1).detach().cpu().numpy(),
        net.source_faces[0].detach().cpu())
    pymesh.save_mesh_raw(
        os.path.join(opt.log_dir, "t{:06d}_template.ply".format(0)),
        net.template_vertices[0].transpose(0, 1).detach().cpu().numpy(),
        net.template_faces[0].detach().cpu())

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, max(int(opt.nepochs * 0.75), 1), gamma=0.5, last_epoch=-1)

    t = 0
    start_epoch = 0
    warmed_up = False
    # The MVC loss is disabled during warm-up and restored afterwards.
    mvc_weight = opt.mvc_weight
    opt.mvc_weight = 0

    # Plain Python float (never a tensor) so it does not keep any step's
    # autograd graph alive; -1 marks "not initialized yet".
    running_avg_loss = -1
    log_interval = min(max(len(dataloader) // 5, 50), 200)
    save_interval = max(opt.nepochs // 10, 1)

    with open(os.path.join(opt.log_dir, "loss_log.txt"), "a") as log_file, \
            torch.autograd.detect_anomaly():
        if opt.epoch:
            start_epoch = opt.epoch % opt.nepochs
            t += start_epoch * len(dataloader)

        for epoch in range(start_epoch, opt.nepochs):
            for epoch_t, data in enumerate(dataloader):
                progress = epoch_t / len(dataloader) + epoch
                if (opt.deform_template or opt.optimize_template) and (
                        progress >= opt.warmup_epochs) and (not warmed_up):
                    if opt.deform_template:
                        optimizer.add_param_group({
                            "params": net.nc_decoder.parameters(),
                            "lr": 0.1 * opt.lr,
                        })
                    if opt.optimize_template:
                        optimizer.add_param_group({
                            "params": net.template_vertices,
                            "lr": 0.1 * opt.lr,
                        })
                    warmed_up = True
                    # Warm-up is over: start computing the MVC loss.
                    opt.mvc_weight = mvc_weight
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="warmed_up")

                ############# get data ###########
                data = dataset.uncollate(data)
                data["cage_edge_points"] = cage_edge_points
                data["cage_edges"] = cage_edges
                data["source_shape"] = net.source_vertices.detach()
                data["source_face"] = net.source_faces.detach()

                ############# run network ###########
                optimizer.zero_grad()
                target_shape_t = data["target_shape"].transpose(1, 2)
                sample_idx = None

                if "sample_idx" in data:
                    sample_idx = data["sample_idx"]
                    # Keep the source normals aligned with the sampled points.
                    if data["source_normals"] is not None:
                        data["source_normals"] = torch.gather(
                            data["source_normals"], 1,
                            sample_idx.unsqueeze(-1).expand(-1, -1, 3))

                outputs = net(target_shape_t, sample_idx)
                # Check the stored value, not only the key: the entry may be None.
                has_source_mesh = ("source_mesh" in data
                                   and data["source_mesh"] is not None)
                if opt.sfnormal_weight > 0 and has_source_mesh:
                    source_mesh = data["source_mesh"]
                    if outputs["deformed"].shape[1] == source_mesh.shape[1]:
                        outputs["deformed_hr"] = outputs["deformed"]
                    else:
                        # Vertex counts differ: deform data["source_mesh"]
                        # itself with the predicted cages.
                        batch = source_mesh.shape[0]
                        outputs["deformed_hr"] = deform_with_MVC(
                            outputs["cage"].expand(batch, -1, -1).detach(),
                            outputs["new_cage"],
                            outputs["cage_face"].expand(batch, -1, -1),
                            source_mesh)
                data["source_shape"] = outputs["source_shape"]

                ############# get losses ###########
                current_loss = all_losses(data, outputs, progress)
                loss_sum = torch.sum(
                    torch.stack(list(current_loss.values()), dim=0))
                loss_value = loss_sum.item()
                if running_avg_loss < 0:
                    running_avg_loss = loss_value
                else:
                    running_avg_loss += (loss_value - running_avg_loss) / (t + 1)

                if t % log_interval == 0 or loss_value > 10 * running_avg_loss:
                    log_str = "warming up {} e {:03d} t {:05d}: {}".format(
                        not warmed_up, epoch, t, ", ".join(
                            "{} {:.3g}".format(k, v.mean().item())
                            for k, v in current_loss.items()))
                    print(log_str)
                    log_file.write(log_str + "\n")
                    log_outputs(opt, t, outputs, data)

                # Skip outlier steps rather than applying a destabilizing update.
                if loss_value > 100 * running_avg_loss:
                    logger.info(
                        "loss ({}) > 100*running_average_loss ({}). Skip without update."
                        .format(loss_value, 100 * running_avg_loss))
                    torch.cuda.empty_cache()
                    continue

                loss_sum.backward()

                if opt.alternate_cd:
                    # Alternate updates: when optimize_C, drop the encoder /
                    # nd_decoder gradients; otherwise drop the cage gradients.
                    optimize_C = (progress > opt.warmup_epochs) and (
                        t % (opt.c_step + opt.d_step)) > opt.d_step
                    if optimize_C:
                        net.nd_decoder.zero_grad()
                        net.encoder.zero_grad()
                    else:
                        try:
                            net.nc_decoder.zero_grad()
                        except AttributeError:
                            # No usable nc_decoder; the template vertices may
                            # not have received a gradient at all.
                            if net.template_vertices.grad is not None:
                                net.template_vertices.grad.zero_()

                optimizer.step()
                if (t + 1) % 500 == 0:
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="latest")

                t += 1

            if (epoch + 1) % save_interval == 0:
                save_network(net,
                             opt.log_dir,
                             network_label="net",
                             epoch_label=epoch)

            scheduler.step()

    save_network(net, opt.log_dir, network_label="net", epoch_label="final")
    test_all(net=net)
예제 #3
0
def test(net=None, subdir="test"):
    """Deform the network's fixed source toward each file in ``opt.target_model``.

    Reads the module-level ``opt`` and sets ``opt.phase = "test"``. When
    ``net`` is ``None``, a ``FixedSourceDeformer`` is rebuilt from
    ``opt.ckpt``; ``opt.template`` and ``opt.source_model`` (mesh paths), when
    set, replace the checkpoint's cage and source.

    For every target, the target is furthest-point sampled to the source's
    point count and fed to the network (cage only). The source is then
    deformed with ``deform_with_MVC``. The sampled target, source (Sa),
    target (Sb), deformed source (Sab) and both cages are written to
    ``opt.log_dir/subdir``, which is then rendered.

    Args:
        net: network to evaluate, or ``None`` to build one from ``opt.ckpt``.
        subdir: output sub-directory under ``opt.log_dir``.

    Raises:
        FileNotFoundError: if a path in ``opt.target_model`` is not a file.
    """
    opt.phase = "test"
    if isinstance(opt.target_model, str):
        opt.target_model = [opt.target_model]

    if net is None:
        states = torch.load(opt.ckpt)
        # Some checkpoints nest the tensors under a "states" key.
        if "states" in states:
            states = states["states"]
        if opt.template:
            cage_shape, cage_face = read_trimesh(opt.template)
            cage_shape = torch.from_numpy(
                cage_shape[:, :3]).unsqueeze(0).float()
            cage_face = torch.from_numpy(cage_face).unsqueeze(0).long()
            states["template_vertices"] = cage_shape.transpose(1, 2)
            states["template_faces"] = cage_face

        if opt.source_model:
            source_shape, source_face = read_trimesh(opt.source_model)
            source_shape = torch.from_numpy(
                source_shape[:, :3]).unsqueeze(0).float()
            source_face = torch.from_numpy(source_face).unsqueeze(0).long()
            states["source_vertices"] = source_shape.transpose(1, 2)
            # Faces, not vertices, belong under "source_faces".
            states["source_faces"] = source_face

        net = networks.FixedSourceDeformer(
            opt,
            3,
            opt.num_point,
            bottleneck_size=opt.bottleneck_size,
            template_vertices=states["template_vertices"],
            template_faces=states["template_faces"],
            source_vertices=states["source_vertices"],
            source_faces=states["source_faces"]).cuda()

        load_network(net, states)
        net = net.cuda()
    net.eval()

    print(net)

    test_output_dir = os.path.join(opt.log_dir, subdir)
    os.makedirs(test_output_dir, exist_ok=True)
    with torch.no_grad():
        # The source is fixed inside the network, so fetch it once.
        source_mesh = net.source_vertices.transpose(1, 2).detach()
        source_face = net.source_faces.detach()
        num_source_points = net.source_vertices.shape[2]
        b = 0  # only the first batch element is written

        for target_model in opt.target_model:
            if not os.path.isfile(target_model):
                raise FileNotFoundError(target_model)
            target_shape, target_face = read_trimesh(target_model)
            target_shape = torch.from_numpy(
                target_shape[:, :3]).cuda().float().unsqueeze(0)
            if target_face is None:
                # No faces were read: fall back to the source faces.
                target_face = net.source_faces
            else:
                target_face = torch.from_numpy(
                    target_face).cuda().long().unsqueeze(0)
            t_filename = os.path.splitext(os.path.basename(target_model))[0]

            # Furthest point sampling down to the source's point count.
            target_shape_sampled = furthest_point_sample(
                target_shape, num_source_points, NCHW=False)[1]
            outputs = net(target_shape_sampled.transpose(1, 2),
                          None,
                          cage_only=True)

            deformed = deform_with_MVC(
                outputs["cage"], outputs["new_cage"],
                outputs["cage_face"].expand(outputs["cage"].shape[0], -1, -1),
                source_mesh)

            save_ply(
                target_shape_sampled[b].cpu().numpy(),
                os.path.join(test_output_dir,
                             "template-{}-Sb.pts".format(t_filename)))
            pymesh.save_mesh_raw(
                os.path.join(test_output_dir,
                             "template-{}-Sa.ply".format(t_filename)),
                source_mesh[0].detach().cpu(), source_face[0].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(test_output_dir,
                             "template-{}-Sb.ply".format(t_filename)),
                target_shape[b].detach().cpu(), target_face[b].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(test_output_dir,
                             "template-{}-Sab.ply".format(t_filename)),
                deformed[b].detach().cpu(), source_face[b].detach().cpu())

            pymesh.save_mesh_raw(
                os.path.join(test_output_dir,
                             "template-{}-cage1.ply".format(t_filename)),
                outputs["cage"][b].detach().cpu(),
                outputs["cage_face"][b].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(test_output_dir,
                             "template-{}-cage2.ply".format(t_filename)),
                outputs["new_cage"][b].detach().cpu(),
                outputs["cage_face"][b].detach().cpu())

    PairedSurreal.render_result(test_output_dir)
예제 #4
0
def test_all(net=None, subdir="test"):
    """Run the fixed-source deformer over the whole test split and save results.

    Reads the module-level ``opt`` and sets ``opt.phase = "test"``. When
    ``net`` is ``None``, a ``FixedSourceDeformer`` is built from the dataset's
    source mesh and cage and restored from ``opt.ckpt``.

    For every test sample, the source is deformed with ``deform_with_MVC``
    using the predicted cages. The source (Sa), target (Sb), deformed source
    (Sab) and both cages are written to ``opt.log_dir/subdir``, with one
    progress line per written item in ``eval.txt`` there. The directory is
    rendered at the end.

    Args:
        net: network to evaluate, or ``None`` to build one from ``opt.ckpt``.
        subdir: output sub-directory under ``opt.log_dir``.
    """
    opt.phase = "test"
    dataset = build_dataset(opt)

    if net is None:
        source_shape = dataset.mesh_vertex.unsqueeze(0).to(dtype=torch.float)
        source_face = dataset.mesh_face.unsqueeze(0)
        cage_shape = dataset.cage_vertex.unsqueeze(0).to(dtype=torch.float)
        cage_face = dataset.cage_face.unsqueeze(0)
        net = networks.FixedSourceDeformer(
            opt,
            3,
            opt.num_point,
            bottleneck_size=opt.bottleneck_size,
            template_vertices=cage_shape.transpose(1, 2),
            template_faces=cage_face,
            source_vertices=source_shape.transpose(1, 2),
            source_faces=source_face).cuda()

        load_network(net, opt.ckpt)
    net.eval()

    print(net)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=3,
        worker_init_fn=lambda worker_id: np.random.seed(
            np.random.get_state()[1][0] + worker_id))

    test_output_dir = os.path.join(opt.log_dir, subdir)
    os.makedirs(test_output_dir, exist_ok=True)
    with open(os.path.join(test_output_dir, "eval.txt"), "w") as f, \
            torch.no_grad():
        # The source is fixed inside the network (a single batch element);
        # fetch it once.
        source_mesh = net.source_vertices.transpose(1, 2).detach()
        source_face = net.source_faces.detach()
        for i, data in enumerate(dataloader):
            data = dataset.uncollate(data)

            target_shape = data["target_shape"]
            target_filename = data["target_file"]

            sample_idx = data["sample_idx"] if "sample_idx" in data else None
            outputs = net(target_shape.transpose(1, 2), sample_idx)

            # Deform the full source mesh with the predicted cages; this is
            # what gets saved instead of the network's own "deformed" output.
            deformed = deform_with_MVC(
                outputs["cage"], outputs["new_cage"],
                outputs["cage_face"].expand(outputs["cage"].shape[0], -1, -1),
                source_mesh)

            for b in range(outputs["deformed"].shape[0]):
                t_filename = os.path.splitext(target_filename[b])[0]
                if data["target_face"] is not None and data[
                        "target_mesh"] is not None:
                    pymesh.save_mesh_raw(
                        os.path.join(test_output_dir,
                                     "template-{}-Sa.ply".format(t_filename)),
                        source_mesh[0].detach().cpu(),
                        source_face[0].detach().cpu())
                    pymesh.save_mesh_raw(
                        os.path.join(test_output_dir,
                                     "template-{}-Sb.ply".format(t_filename)),
                        data["target_mesh"][b].detach().cpu(),
                        data["target_face"][b].detach().cpu())
                    # The source faces have batch size 1, so index with 0
                    # rather than b.
                    pymesh.save_mesh_raw(
                        os.path.join(test_output_dir,
                                     "template-{}-Sab.ply".format(t_filename)),
                        deformed[b].detach().cpu(),
                        source_face[0].detach().cpu())
                else:
                    save_ply(
                        source_mesh[0].detach().cpu(),
                        os.path.join(test_output_dir,
                                     "template-{}-Sa.ply".format(t_filename)))
                    save_ply(
                        target_shape[b].detach().cpu(),
                        os.path.join(test_output_dir,
                                     "template-{}-Sb.ply".format(t_filename)),
                        normals=data["target_normals"][b].detach().cpu())
                    # NOTE(review): the deformed source is saved with the
                    # target's normals; confirm that is intended.
                    save_ply(
                        deformed[b].detach().cpu(),
                        os.path.join(test_output_dir,
                                     "template-{}-Sab.ply".format(t_filename)),
                        normals=data["target_normals"][b].detach().cpu())

                pymesh.save_mesh_raw(
                    os.path.join(test_output_dir,
                                 "template-{}-cage1.ply".format(t_filename)),
                    outputs["cage"][b].detach().cpu(),
                    outputs["cage_face"][b].detach().cpu())
                pymesh.save_mesh_raw(
                    os.path.join(test_output_dir,
                                 "template-{}-cage2.ply".format(t_filename)),
                    outputs["new_cage"][b].detach().cpu(),
                    outputs["cage_face"][b].detach().cpu())

                log_str = "{}/{} {}".format(i, len(dataloader), t_filename)
                print(log_str)
                f.write(log_str + "\n")

    dataset.render_result(test_output_dir)
예제 #5
0
def test(net=None, save_subdir="test"):
    """Evaluate a ``NetworkFull`` cage deformer on every test source/target pair.

    Reads the module-level ``opt`` and sets ``opt.phase = "test"``. When
    ``net`` is ``None``, the network is built around an initial cage and
    restored from ``opt.ckpt``. For ``opt.dim == 3`` the cage is loaded from
    ``opt.template``; otherwise a polygon cage with ``opt.cage_deg`` vertices
    is generated.

    For each pair the network predicts a source cage and a deformed cage:

    * If the sample carries a source mesh (a tensor or a file path), that
      mesh is deformed with ``deform_with_MVC``.
    * Otherwise the network's own ``"deformed"`` output is used.

    With ``opt.blend_style`` the pair is evaluated at four evenly spaced blend
    factors in ``[0, 1]``, giving one ``Sab-<b>`` result per factor.

    The source (Sa), target (Sb), deformed source (Sab-<b>) and both cages are
    written to ``opt.log_dir/save_subdir``, mostly re-centered with
    ``center_bounding_box``. A progress line per pair goes to ``eval.txt``
    there, and the directory is rendered at the end.

    Args:
        net: network to evaluate, or ``None`` to build one from ``opt.ckpt``.
        save_subdir: output sub-directory under ``opt.log_dir``.
    """
    opt.phase = "test"
    dataset = build_dataset(opt)

    if net is None:
        # The initial cage is only needed to construct the network.
        if opt.dim == 3:
            init_cage_V, init_cage_Fs = loadInitCage([opt.template])
            cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        else:
            # 2D: a generated polygon whose single face lists all
            # opt.cage_deg vertices in order.
            init_cage_V = generatePolygon(0, 0, 1.5, 0, 0, 0, opt.cage_deg)
            init_cage_V = torch.tensor([(x, y) for x, y in init_cage_V],
                                       dtype=torch.float).unsqueeze(0)
            cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
            init_cage_Fs = [
                torch.arange(opt.cage_deg,
                             dtype=torch.int64).view(1, 1, -1).cuda()
            ]

        net = networks.NetworkFull(
            opt,
            dim=opt.dim,
            bottleneck_size=opt.bottleneck_size,
            template_vertices=cage_V_t,
            template_faces=init_cage_Fs[-1],
        ).cuda()
        net.eval()
        load_network(net, opt.ckpt)
    else:
        net.eval()

    print(net)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=tolerating_collate,
        num_workers=0,
        worker_init_fn=lambda worker_id: np.random.seed(
            np.random.get_state()[1][0] + worker_id))

    test_output_dir = os.path.join(opt.log_dir, save_subdir)
    os.makedirs(test_output_dir, exist_ok=True)
    with open(os.path.join(test_output_dir, "eval.txt"), "w") as f, \
            torch.no_grad():
        for i, data in enumerate(dataloader):
            data = dataset.uncollate(data)

            ############# blending ############
            if opt.blend_style:
                # Evaluate 4 evenly spaced blend factors in [0, 1] by
                # repeating the pair along the batch dimension.
                num_alpha = 4
                blend_alpha = torch.linspace(
                    0, 1, steps=num_alpha,
                    dtype=torch.float32).cuda().reshape(num_alpha, 1)
                data["source_shape"] = data["source_shape"].expand(
                    num_alpha, -1, -1).contiguous()
                data["target_shape"] = data["target_shape"].expand(
                    num_alpha, -1, -1).contiguous()
            else:
                blend_alpha = 1.0

            data["alpha"] = blend_alpha

            ###################################
            source_shape_t = data["source_shape"].transpose(
                1, 2).contiguous().detach()
            target_shape_t = data["target_shape"].transpose(
                1, 2).contiguous().detach()

            outputs = net(source_shape_t, target_shape_t, blend_alpha)
            deformed = outputs["deformed"]

            ####################### evaluation ########################
            s_filename = os.path.splitext(data["source_file"][0])[0]
            t_filename = os.path.splitext(data["target_file"][0])[0]

            log_str = "{}/{} {}-{} ".format(i, len(dataloader), s_filename,
                                            t_filename)
            print(log_str)
            f.write(log_str + "\n")

            # Resolve the optional source mesh once per pair; it is the same
            # for every blend factor b.
            source_mesh_entry = (data["source_mesh"]
                                 if "source_mesh" in data else None)
            loaded_source_mesh = None
            mvc_source = None
            if source_mesh_entry is not None:
                if isinstance(source_mesh_entry[0], str):
                    # A file path: load it and normalize with the dataset.
                    points = om.read_polymesh(
                        source_mesh_entry[0]).points().copy()
                    points = dataset.normalize(points, opt.isV2)
                    loaded_source_mesh = torch.from_numpy(
                        points.astype(np.float32)).unsqueeze(0).cuda()
                    mvc_source = loaded_source_mesh
                else:
                    mvc_source = source_mesh_entry

            ###################### outputs ############################
            for b in range(deformed.shape[0]):
                # Result for batch element b, kept as a batch of one. The MVC
                # result has batch size 1, so indexing it with b would fail
                # for b > 0.
                if mvc_source is not None:
                    deformed_b = deform_with_MVC(
                        outputs["cage"][b:b + 1],
                        outputs["new_cage"][b:b + 1],
                        outputs["cage_face"], mvc_source)
                else:
                    deformed_b = deformed[b:b + 1]
                deformed_b[0] = center_bounding_box(deformed_b[0])[0]

                source_face_entry = data["source_face"]
                if source_face_entry is not None and source_mesh_entry is not None:
                    source_mesh = source_mesh_entry[0].detach().cpu()
                    source_mesh = center_bounding_box(source_mesh)[0]
                    source_face = source_face_entry[0].detach().cpu()
                    tosave = pymesh.form_mesh(vertices=source_mesh,
                                              faces=source_face)
                    pymesh.save_mesh(
                        os.path.join(
                            test_output_dir,
                            "{}-{}-Sa.obj".format(s_filename, t_filename)),
                        tosave,
                        use_float=True)
                    tosave = pymesh.form_mesh(
                        vertices=deformed_b[0].detach().cpu(),
                        faces=source_face)
                    pymesh.save_mesh(
                        os.path.join(
                            test_output_dir,
                            "{}-{}-Sab-{}.obj".format(
                                s_filename, t_filename, b)),
                        tosave,
                        use_float=True,
                    )
                elif (source_face_entry is None
                      and source_mesh_entry is not None
                      and isinstance(source_mesh_entry[0], str)):
                    # No faces supplied: reuse the original file's
                    # connectivity and overwrite only its vertex positions.
                    mesh = om.read_polymesh(source_mesh_entry[0])
                    points_arr = mesh.points()
                    points_arr[:] = loaded_source_mesh[0].detach().cpu().numpy()
                    om.write_mesh(
                        os.path.join(
                            test_output_dir,
                            "{}-{}-Sa.obj".format(s_filename, t_filename)),
                        mesh)
                    points_arr[:] = deformed_b[0].detach().cpu().numpy()
                    om.write_mesh(
                        os.path.join(
                            test_output_dir,
                            "{}-{}-Sab-{}.obj".format(
                                s_filename, t_filename, b)), mesh)
                else:
                    # save to "pts" for rendering
                    save_pts(
                        os.path.join(
                            test_output_dir,
                            "{}-{}-Sa.pts".format(s_filename, t_filename)),
                        data["source_shape"][b].detach().cpu())
                    save_pts(
                        os.path.join(
                            test_output_dir,
                            "{}-{}-Sab-{}.pts".format(
                                s_filename, t_filename, b)),
                        deformed_b[0].detach().cpu())

                if data["target_face"] is not None and data[
                        "target_mesh"] is not None:
                    data["target_mesh"][0] = center_bounding_box(
                        data["target_mesh"][0])[0]
                    tosave = pymesh.form_mesh(
                        vertices=data["target_mesh"][0].detach().cpu(),
                        faces=data["target_face"][0].detach().cpu())
                    pymesh.save_mesh(
                        os.path.join(
                            test_output_dir,
                            "{}-{}-Sb.obj".format(s_filename, t_filename)),
                        tosave,
                        use_float=True,
                    )
                elif (data["target_face"] is None
                      and data["target_mesh"] is not None
                      and isinstance(data["target_mesh"][0], str)):
                    mesh = om.read_polymesh(data["target_mesh"][0])
                    points_arr = mesh.points()
                    points_arr[:] = dataset.normalize(
                        points_arr.copy(), opt.isV2)
                    om.write_mesh(
                        os.path.join(
                            test_output_dir,
                            "{}-{}-Sb.obj".format(s_filename, t_filename)),
                        mesh)
                else:
                    save_pts(
                        os.path.join(
                            test_output_dir,
                            "{}-{}-Sb.pts".format(s_filename, t_filename)),
                        data["target_shape"][0].detach().cpu())

                outputs["cage"][b] = center_bounding_box(
                    outputs["cage"][b])[0]
                outputs["new_cage"][b] = center_bounding_box(
                    outputs["new_cage"][b])[0]
                pymesh.save_mesh_raw(
                    os.path.join(
                        test_output_dir,
                        "{}-{}-cage1-{}.ply".format(
                            s_filename, t_filename, b)),
                    outputs["cage"][b].detach().cpu(),
                    outputs["cage_face"][0].detach().cpu(),
                    binary=True)
                pymesh.save_mesh_raw(
                    os.path.join(
                        test_output_dir,
                        "{}-{}-cage2-{}.ply".format(
                            s_filename, t_filename, b)),
                    outputs["new_cage"][b].detach().cpu(),
                    outputs["cage_face"][0].detach().cpu(),
                    binary=True)

    dataset.render_result(test_output_dir)
예제 #6
0
def train():
    """Train ``networks.NetworkFull`` using the module-level ``opt`` settings.

    Steps:
      * build the dataset and a shuffling ``DataLoader``;
      * build the initial cage: loaded from ``opt.template`` when
        ``opt.dim == 3`` (together with its edge points / edge indices, which
        are attached to every batch for the losses), otherwise a regular
        polygon with ``opt.cage_deg`` vertices;
      * optimize with Adam (plus optional ``nc_decoder`` / template-vertex
        parameter groups) and a StepLR schedule for ``opt.nepochs`` epochs;
      * snapshot the training sources into ``opt.log_dir``, append progress
        to ``training_log.txt`` there, and save checkpoints every 500 steps
        ("latest"), every ``save_interval`` epochs, and at the end ("final");
      * run ``test`` after every epoch when ``opt.eval`` is set, and once at
        the end.
    """
    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=tolerating_collate,
        num_workers=2,
        # Give every worker a distinct numpy seed derived from the parent RNG.
        worker_init_fn=lambda worker_id: np.random.seed(
            np.random.get_state()[1][0] + worker_id))

    if opt.dim == 3:
        # cage (1,N,3)
        init_cage_V, init_cage_Fs = loadInitCage([opt.template])
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        cage_edge_points_list = []
        cage_edges_list = []
        for F in init_cage_Fs:
            mesh = Mesh(vertices=init_cage_V[0], faces=F[0])
            build_gemm(mesh, F[0])
            cage_edge_points_list.append(
                torch.from_numpy(get_edge_points(mesh)).cuda())
            # Append (rather than rebind) so both lists stay index-aligned;
            # the training loop only consumes the last entry of each.
            cage_edges_list.append(edge_vertex_indices(F[0]))
    else:
        init_cage_V = generatePolygon(0, 0, 1.5, 0, 0, 0, opt.cage_deg)
        init_cage_V = torch.tensor([(x, y) for x, y in init_cage_V],
                                   dtype=torch.float).unsqueeze(0)
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        init_cage_Fs = [
            torch.arange(opt.cage_deg, dtype=torch.int64).view(1, 1, -1).cuda()
        ]

    # network
    net = networks.NetworkFull(
        opt,
        dim=opt.dim,
        bottleneck_size=opt.bottleneck_size,
        template_vertices=cage_V_t,
        template_faces=init_cage_Fs[-1],
    ).cuda()

    net.apply(weights_init)
    if opt.ckpt:
        load_network(net, opt.ckpt)

    all_losses = losses.AllLosses(opt)

    # optimizer
    optimizer = torch.optim.Adam([
        {"params": net.encoder.parameters()},
        {"params": net.nd_decoder.parameters()},
        {"params": net.merger.parameters()},
    ], lr=opt.lr)
    if opt.full_net:
        # nc_decoder parameters use a 10x smaller learning rate.
        optimizer.add_param_group({
            'params': net.nc_decoder.parameters(),
            'lr': 0.1 * opt.lr
        })
    if opt.optimize_template:
        optimizer.add_param_group({
            'params': net.template_vertices,
            'lr': opt.lr
        })

    # Decay the LR by 10x every 40% of training. StepLR takes last_epoch
    # modulo step_size, so step_size must stay >= 1 for runs with < 3 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                max(1, int(opt.nepochs * 0.4)),
                                                gamma=0.1,
                                                last_epoch=-1)

    # train
    net.train()
    start_epoch = 0
    t = 0  # global step counter across epochs

    # Snapshot the sources used for this run next to its logs/checkpoints.
    os.makedirs(opt.log_dir, exist_ok=True)
    shutil.copy2(__file__, opt.log_dir)
    src_dir = os.path.dirname(__file__)
    for src_name in ("networks.py", "losses.py", "datasets.py", "common.py",
                     "option.py"):
        shutil.copy2(os.path.join(src_dir, src_name), opt.log_dir)
    print(net)

    log_interval = max(len(dataloader) // 5, 50)
    save_interval = max(opt.nepochs // 10, 1)
    # Kept as a plain float: storing the loss tensor here kept its autograd
    # graph alive and chained a new graph node onto it on every step.
    running_avg_loss = -1.0

    if opt.epoch:
        start_epoch = opt.epoch % opt.nepochs
        t += start_epoch * len(dataloader)

    with open(os.path.join(opt.log_dir, "training_log.txt"), "a") as log_file, \
            torch.autograd.detect_anomaly():
        log_file.write(str(net) + "\n")

        for epoch in range(start_epoch, opt.nepochs):
            for t_epoch, data in enumerate(dataloader):
                warming_up = epoch < opt.warmup_epochs
                progress = t_epoch / len(dataloader) + epoch

                ############# get data ###########
                data = dataset.uncollate(data)
                data = crisscross_input(data)
                if opt.dim == 3:
                    data["cage_edge_points"] = cage_edge_points_list[-1]
                    data["cage_edges"] = cage_edges_list[-1]
                source_shape = data["source_shape"]
                target_shape = data["target_shape"]

                ############# blending ############
                if opt.blend_style:
                    # One random blend weight per sample in the batch.
                    blend_alpha = torch.rand(
                        (source_shape.shape[0], 1),
                        dtype=torch.float32).to(device=source_shape.device)
                else:
                    blend_alpha = 1.0
                data["alpha"] = blend_alpha

                ############# run network ###########
                optimizer.zero_grad()
                outputs = net(source_shape.transpose(1, 2),
                              target_shape.transpose(1, 2), data["alpha"])

                ############# get losses ###########
                current_loss = all_losses(data, outputs, progress)
                loss_sum = torch.sum(
                    torch.stack(list(current_loss.values()), dim=0))
                loss_value = loss_sum.item()
                if running_avg_loss < 0:
                    running_avg_loss = loss_value
                else:
                    running_avg_loss += (loss_value -
                                         running_avg_loss) / (t + 1)

                # Log periodically, and also whenever the loss jumps above
                # 5x the running mean.
                if t % log_interval == 0 or loss_value > 5 * running_avg_loss:
                    loss_summary = ", ".join(
                        "{} {:.3g}".format(k, v.mean().item())
                        for k, v in current_loss.items())
                    log_str = "warming up {} e {:03d} t {:05d}: {}".format(
                        warming_up, epoch, t, loss_summary)
                    print(log_str)
                    log_file.write(log_str + "\n")
                    log_outputs(opt, t, outputs, data)

                # Skip the parameter update on a severe loss spike.
                # Note: t is not advanced for skipped batches.
                if loss_value > 100 * running_avg_loss:
                    logger.info(
                        "loss ({}) > 100*running_average_loss ({}). Skip without update."
                        .format(loss_value, 100 * running_avg_loss))
                    torch.cuda.empty_cache()
                    continue

                loss_sum.backward()
                if warming_up:
                    # During warm-up, drop the gradients of nc_decoder and
                    # encoder (falling back to the template vertices when the
                    # network has no nc_decoder).
                    try:
                        net.nc_decoder.zero_grad()
                        net.encoder.zero_grad()
                    except AttributeError:
                        net.template_vertices.grad.zero_()

                if opt.alternate_cd:
                    # Alternate, per epoch block, which part of the network
                    # receives gradients.
                    optimize_C = (epoch > opt.warmup_epochs) and (
                        epoch % (opt.c_epoch + opt.d_epoch)) > opt.d_epoch
                    if optimize_C:
                        net.nd_decoder.zero_grad()
                    else:
                        try:
                            net.encoder.zero_grad()
                            net.nc_decoder.zero_grad()
                        except AttributeError:
                            net.template_vertices.grad.zero_()

                clamp_gradient(net, 0.1)
                optimizer.step()

                if (t + 1) % 500 == 0:
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="latest")

                t += 1

            if (epoch + 1) % save_interval == 0:
                save_network(net,
                             opt.log_dir,
                             network_label="net",
                             epoch_label=epoch)

            scheduler.step()
            if opt.eval:
                try:
                    test(net=net, save_subdir="epoch_{}".format(epoch))
                except Exception as e:
                    # Per-epoch evaluation is best-effort: report and go on.
                    traceback.print_exc(file=sys.stdout)
                    logger.warning("Failed to run test: {}".format(e))

    save_network(net, opt.log_dir, network_label="net", epoch_label="final")
    test(net=net)
예제 #7
0
def test_all(opt, new_cage_shape):
    """Deform the source model ``opt.model`` towards every test target.

    Loads a ``networks.FixedSourceDeformer`` from ``opt.ckpt`` and, for each
    target shape in the test dataset, predicts a cage and a deformed cage. It
    then deforms the source mesh vertices with ``deform_with_MVC`` and writes
    the following under ``<opt.log_dir>/<opt.subdir>``: the source
    (``template-Sa.ply``), each deformed source (``template-<t>-Sab.obj``),
    each target (``template-<t>-Sb.ply``) and both cages
    (``template-<t>-cage1/2.ply``).

    Args:
        opt: option namespace. It is mutated in place (``phase`` is set to
            ``"test"`` and ``target_model`` to ``None``) before the dataset
            is built.
        new_cage_shape: cage used when ``opt.d_residual`` is set. The
            predicted offset ``new_cage - cage`` is re-applied on top of it.
    """
    opt.phase = "test"
    opt.target_model = None
    print(opt.model)

    output_dir = os.path.join(opt.log_dir, opt.subdir)
    os.makedirs(output_dir, exist_ok=True)

    if opt.is_poly:
        source_mesh = om.read_polymesh(opt.model)
    else:
        source_mesh = om.read_trimesh(opt.model)

    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, drop_last=False,
        collate_fn=tolerating_collate, num_workers=0,
        worker_init_fn=lambda worker_id: np.random.seed(
            np.random.get_state()[1][0] + worker_id))

    states = torch.load(opt.ckpt)
    if "states" in states:
        states = states["states"]

    om.write_mesh(os.path.join(output_dir, "template-Sa.ply"), source_mesh)

    net = networks.FixedSourceDeformer(
        opt, 3, opt.num_point, bottleneck_size=opt.bottleneck_size,
        template_vertices=states["template_vertices"],
        template_faces=states["template_faces"].cuda(),
        source_vertices=states["source_vertices"],
        source_faces=states["source_faces"]).cuda()
    load_network(net, states)
    # Inference only: use eval-mode layer behaviour, as test_one does.
    net.eval()

    # Copy of the original source vertices; source_mesh itself is
    # overwritten with each deformed result below.
    source_points = torch.from_numpy(
        source_mesh.points().copy()).float().cuda().unsqueeze(0)
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            data = dataset.uncollate(data)

            target_shape = data["target_shape"]
            target_filename = data["target_file"]
            logger.info(str(target_filename[0]))

            outputs = net(target_shape.transpose(1, 2), cage_only=True)
            if opt.d_residual:
                cage_offset = outputs["new_cage"] - outputs["cage"]
                outputs["cage"] = new_cage_shape
                outputs["new_cage"] = new_cage_shape + cage_offset

            # Broadcast the cage faces over the batch once; they are reused
            # both for MVC deformation and when saving each cage.
            cage_faces = outputs["cage_face"].expand(
                outputs["cage"].shape[0], -1, -1)
            deformed = deform_with_MVC(outputs["cage"], outputs["new_cage"],
                                       cage_faces, source_points)

            for b in range(deformed.shape[0]):
                t_filename = os.path.splitext(target_filename[b])[0]
                # Writing through points() updates source_mesh in place, so
                # the mesh written next carries the deformed positions.
                source_mesh_arr = source_mesh.points()
                source_mesh_arr[:] = deformed[b].cpu().detach().numpy()
                om.write_mesh(
                    os.path.join(output_dir,
                                 "template-{}-Sab.obj".format(t_filename)),
                    source_mesh)
                pymesh.save_mesh_raw(
                    os.path.join(output_dir,
                                 "template-{}-Sb.ply".format(t_filename)),
                    data["target_mesh"][b].detach().cpu(),
                    data["target_face"][b].detach().cpu())
                pymesh.save_mesh_raw(
                    os.path.join(output_dir,
                                 "template-{}-cage1.ply".format(t_filename)),
                    outputs["cage"][b].detach().cpu(),
                    cage_faces[b].detach().cpu())
                pymesh.save_mesh_raw(
                    os.path.join(output_dir,
                                 "template-{}-cage2.ply".format(t_filename)),
                    outputs["new_cage"][b].detach().cpu(),
                    cage_faces[b].detach().cpu())

            if i % 20 == 0:
                logger.success("[{}/{}] Done".format(i, len(dataloader)))

    dataset.render_result(output_dir)