Code Example #1
    def forward(self, source_shape, target_shape, alpha=1.0):
        """
        source_shape (B,3,N)
        target_shape (B,3,M)
        init_cage    (B,3,P)
        return:
            deformed (B,3,N)
            cage     (B,3,P)
            new_cage (B,3,P)
            weights  (B,N,P)
        """
        B, _, N = source_shape.shape
        _, _, M = target_shape.shape
        _, _, P = self.template_vertices.shape
        ############ Encoder #############
        input_shapes = torch.cat([source_shape, target_shape], dim=0)
        shape_code = self.encoder(input_shapes)
        shape_code.unsqueeze_(-1)
        s_code, t_code = torch.split(shape_code, B, dim=0)

        # ############ Cage ################
        cage = self.template_vertices.view(1, self.dim, -1).expand(B, -1, -1)
        if self.opt.full_net:
            cage = self.nc_decoder(s_code, cage)
            # cage.register_hook(save_grad("d_cage"))

        ########### Deform ##########
        # first fold use global feature
        target_code = torch.cat([s_code, t_code], dim=1)
        target_code = self.merger(target_code)
        new_cage = self.nd_decoder(target_code, cage)

        ########### MVC ##############
        if self.dim == 3:
            cage = cage.transpose(1, 2).contiguous()
            new_cage = new_cage.transpose(1, 2).contiguous()
            deformed_shapes, weights, weights_unnormed = deform_with_MVC(
                cage,
                new_cage,
                self.template_faces.expand(B, -1, -1),
                source_shape.transpose(1, 2).contiguous(),
                verbose=True)
        elif self.dim == 2:
            weights, weights_unnormed = mean_value_coordinates(source_shape,
                                                               cage,
                                                               verbose=True)
            deformed_shapes = torch.sum(weights.unsqueeze(1) *
                                        new_cage.unsqueeze(-1),
                                        dim=2).transpose(1, 2).contiguous()
            cage = cage.transpose(1, 2)
            new_cage = new_cage.transpose(1, 2)

        return {
            "cage": cage,
            "new_cage": new_cage,
            "deformed": deformed_shapes,
            "cage_face": self.template_faces,
            "weight": weights,
            "weight_unnormed": weights_unnormed
        }
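
The MVC step reduces to a per-point linear combination of the deformed cage vertices once the weights are known. As a shape sanity check against the docstring above (weights (B,N,P), new_cage (B,3,P), deformed (B,3,N)), here is a minimal self-contained sketch with random stand-in weights; it is not the repository's deform_with_MVC, only the final weighted sum:

import torch

B, N, P = 2, 1024, 30
weights = torch.rand(B, N, P)
weights = weights / weights.sum(dim=-1, keepdim=True)  # MVC weights are normalized per point
new_cage = torch.rand(B, 3, P)                          # deformed cage vertices

# deformed[b, :, n] = sum_p weights[b, n, p] * new_cage[b, :, p]
deformed = torch.sum(weights.unsqueeze(1) * new_cage.unsqueeze(-2), dim=-1)  # (B, 3, N)
# the same contraction written as a batched matrix product
deformed_bmm = torch.bmm(new_cage, weights.transpose(1, 2))                  # (B, 3, N)
assert torch.allclose(deformed, deformed_bmm, atol=1e-6)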
Code Example #2
File: network2.py  Project: star-cold/deep_cage
    def forward(self, source_shape, target_shape, alpha=1.0):
        """
        source_shape (B,3,N)
        target_shape (B,3,M)
        return:
            deformed (B,3,N)
            cage     (B,3,P)
            new_cage (B,3,P)
            weights  (B,N,P)
        """
        B, _, N = source_shape.shape
        _, _, M = target_shape.shape
        _, _, P = self.template.shape

        ############ Encoder #############
        input_shapes = torch.cat([source_shape, target_shape], dim=0)
        shape_code = self.encoder(input_shapes)
        shape_code.unsqueeze_(-1)
        s_code, t_code = torch.split(shape_code, B, dim=0)

        ############ sample on the spherical template  ##############
        sphere_sample = torch.from_numpy(random_sphere(B, N - P)).to(
            device=source_shape.device, dtype=source_shape.dtype)
        sphere_sample = sphere_sample.transpose(1, 2).contiguous()
        template = self.template.view(1, 2, -1).expand(B, -1, -1)
        template_v = angle_to_xyz(template, self.r)
        template_v += torch.randn_like(template_v) * 0.01
        sphere_sample = torch.cat([template_v, sphere_sample], dim=-1)

        ########### Cage ############
        cage_surface, cage_feat = self.nc_decoder(s_code, sphere_sample)
        cage = cage_surface[:, :, :P]

        ########### Deform ##########
        if not self.D_use_C_global_code:
            s_code = cage_feat
        # alpha may be a Python float (default 1.0) or a (B,1) tensor; make it broadcastable
        if not torch.is_tensor(alpha):
            alpha = torch.full((B, 1), alpha, device=t_code.device, dtype=t_code.dtype)
        alpha = alpha.unsqueeze(-1)
        t_code = alpha * t_code + (1 - alpha) * s_code
        st_code = torch.cat([s_code, t_code], dim=1)
        # ts_code = torch.cat([t_code, s_code], dim=1)
        # target_code = torch.cat([st_code, ts_code], dim=0)
        target_code = self.merger(st_code)

        deformed_surface = self.nd_decoder(target_code, cage_surface)
        new_cage = deformed_surface[:, :, :P]

        ########### MVC ##############
        deformed_shapes, weights, weights_unnormed = deform_with_MVC(
            cage.transpose(1, 2).contiguous(),
            new_cage.transpose(1, 2).contiguous(),
            self.template_faces.expand(B, -1, -1),
            source_shape.transpose(1, 2).contiguous(),
            verbose=True)

        ########### prepare for output ##############
        cage = cage.transpose(1, 2)
        new_cage = new_cage.transpose(1, 2)
        cage_surface = cage_surface.transpose(1, 2)
        deformed_surface = deformed_surface.transpose(1, 2)
        # cage, t_cage = torch.split(cage, B, dim=0)
        # s_new_cage, t_new_cage = torch.split(new_cage, B, dim=0)
        # cage_surface, t_surface = torch.split(cage_surface, B, dim=0)
        # s_deformed_surface, t_deformed_surface = torch.split(deformed_surface, B, dim=0)
        # deformed, t_deformed = torch.split(deformed_shapes, B, dim=0)

        return {
            "cage": cage,
            "cage_surface": cage_surface,
            # "t_surface": t_surface,
            "new_cage": new_cage,
            # "t_new_cage": t_new_cage,
            "new_cage_surface": deformed_surface,
            # "t_deformed_surface": t_deformed_surface,
            "deformed": deformed_shapes,
            # "t_deformed": t_deformed,
            "cage_face": self.template_faces,
            "weight": weights,
            "weight_unnormed": weights_unnormed
        }
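
In this variant the target code is blended with the source code before decoding: alpha = 0 keeps the source code and alpha = 1 uses the target code, which is what enables the partial "style" blends in the test script further below. A minimal sketch of that interpolation with dummy codes (the code width is an assumption, not taken from the repository):

import torch

B, C = 4, 256                                         # batch size and an assumed code width
s_code = torch.randn(B, C, 1)                         # source latent code
t_code = torch.randn(B, C, 1)                         # target latent code
alpha = torch.linspace(0, 1, steps=B).view(B, 1, 1)   # one blend factor per batch element

blended = alpha * t_code + (1 - alpha) * s_code       # same formula as in forward()
assert torch.allclose(blended[0], s_code[0])          # alpha = 0 keeps the source code
assert torch.allclose(blended[-1], t_code[-1])        # alpha = 1 uses the target code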
Code Example #3
def train():
    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=0,
        worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] +
                                                 id))
    source_shape = dataset.mesh_vertex.unsqueeze(0).to(dtype=torch.float)
    source_face = dataset.mesh_face.unsqueeze(0)
    cage_shape = dataset.cage_vertex.unsqueeze(0).to(dtype=torch.float)
    cage_face = dataset.cage_face.unsqueeze(0)
    mesh = Mesh(vertices=cage_shape[0], faces=cage_face[0])
    build_gemm(mesh, cage_face[0])
    cage_edge_points = torch.from_numpy(get_edge_points(mesh)).cuda()
    cage_edges = edge_vertex_indices(cage_face[0])

    # network
    net = networks.FixedSourceDeformer(
        opt,
        3,
        opt.num_point,
        bottleneck_size=opt.bottleneck_size,
        template_vertices=cage_shape.transpose(1, 2),
        template_faces=cage_face,
        source_vertices=source_shape.transpose(1, 2),
        source_faces=source_face).cuda()
    print(net)
    net.apply(weights_init)
    if opt.ckpt:
        load_network(net, opt.ckpt)
    net.train()

    all_losses = losses.AllLosses(opt)

    # optimizer
    optimizer = torch.optim.Adam([{
        'params': net.nd_decoder.parameters()
    }, {
        "params": net.encoder.parameters()
    }],
                                 lr=opt.lr)

    # train
    os.makedirs(opt.log_dir, exist_ok=True)
    shutil.copy2(__file__, opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "network2.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "common.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "losses.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "datasets.py"),
                 opt.log_dir)
    pymesh.save_mesh_raw(
        os.path.join(opt.log_dir, "t{:06d}_Sa.ply".format(0)),
        net.source_vertices[0].transpose(0, 1).detach().cpu().numpy(),
        net.source_faces[0].detach().cpu())
    pymesh.save_mesh_raw(
        os.path.join(opt.log_dir, "t{:06d}_template.ply".format(0)),
        net.template_vertices[0].transpose(0, 1).detach().cpu().numpy(),
        net.template_faces[0].detach().cpu())

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                max(int(opt.nepochs * 0.75),
                                                    1),
                                                gamma=0.5,
                                                last_epoch=-1)

    # train
    net.train()
    t = 0
    start_epoch = 0
    warmed_up = False
    mvc_weight = opt.mvc_weight
    opt.mvc_weight = 0

    os.makedirs(opt.log_dir, exist_ok=True)
    running_avg_loss = -1
    log_file = open(os.path.join(opt.log_dir, "loss_log.txt"), "a")
    log_interval = min(max(len(dataloader) // 5, 50), 200)
    save_interval = max(opt.nepochs // 10, 1)

    with torch.autograd.detect_anomaly():
        if opt.epoch:
            start_epoch = opt.epoch % opt.nepochs
            t += start_epoch * len(dataloader)

        for epoch in range(start_epoch, opt.nepochs):
            for epoch_t, data in enumerate(dataloader):
                progress = epoch_t / len(dataloader) + epoch
                warming_up = progress < opt.warmup_epochs
                if (opt.deform_template or opt.optimize_template) and (
                        progress >= opt.warmup_epochs) and (not warmed_up):
                    if opt.deform_template:
                        optimizer.add_param_group({
                            'params':
                            net.nc_decoder.parameters(),
                            'lr':
                            0.1 * opt.lr
                        })
                    if opt.optimize_template:
                        optimizer.add_param_group({
                            'params': net.template_vertices,
                            'lr': 0.1 * opt.lr
                        })
                    warmed_up = True
                    # start to compute mvc weight
                    opt.mvc_weight = mvc_weight
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="warmed_up")

                ############# get data ###########
                data = dataset.uncollate(data)
                data["cage_edge_points"] = cage_edge_points
                data["cage_edges"] = cage_edges
                data["source_shape"] = net.source_vertices.detach()
                data["source_face"] = net.source_faces.detach()

                ############# run network ###########
                optimizer.zero_grad()
                target_shape_t = data["target_shape"].transpose(1, 2)
                sample_idx = None

                if "sample_idx" in data:
                    sample_idx = data["sample_idx"]
                    if data["source_normals"] is not None:
                        data["source_normals"] = torch.gather(
                            data["source_normals"], 1,
                            sample_idx.unsqueeze(-1).expand(-1, -1, 3))

                outputs = net(target_shape_t, sample_idx)
                if opt.sfnormal_weight > 0 and ("source_mesh" in data
                                                and data["source_mesh"] is not None):
                    if outputs["deformed"].shape[1] == data[
                            "source_mesh"].shape[1]:
                        outputs["deformed_hr"] = outputs["deformed"]
                    else:
                        outputs["deformed_hr"] = deform_with_MVC(
                            outputs["cage"].expand(
                                data["source_mesh"].shape[0], -1, -1).detach(),
                            outputs["new_cage"], outputs["cage_face"].expand(
                                data["source_mesh"].shape[0], -1,
                                -1), data["source_mesh"])
                data["source_shape"] = outputs["source_shape"]

                ############# get losses ###########
                current_loss = all_losses(data, outputs, progress)
                loss_sum = torch.sum(
                    torch.stack([v for v in current_loss.values()], dim=0))
                if running_avg_loss < 0:
                    running_avg_loss = loss_sum.item()
                else:
                    running_avg_loss = running_avg_loss + (
                        loss_sum.item() - running_avg_loss) / (t + 1)

                if (t % log_interval
                        == 0) or (loss_sum > 10 * running_avg_loss):
                    log_str = "warming up {} e {:03d} t {:05d}: {}".format(
                        not warmed_up, epoch, t, ", ".join([
                            "{} {:.3g}".format(k,
                                               v.mean().item())
                            for k, v in current_loss.items()
                        ]))
                    print(log_str)
                    log_file.write(log_str + "\n")
                    log_outputs(opt, t, outputs, data)
                    # save_ply(data["target_shape"][0].detach().cpu().numpy(), os.path.join(opt.log_dir,"step-{:06d}-Sb.ply".format(t)))
                    # save_ply(outputs["deformed"][0].detach().cpu().numpy(), os.path.join(opt.log_dir,"step-{:06d}-Sab.ply".format(t)))
                    # write_trimesh(os.path.join(opt.log_dir, "step-{:06d}-cage1.ply".format(t)),
                    #               outputs["cage"][0].detach().cpu(), outputs["cage_face"][0].detach().cpu(), binary=True)
                    # write_trimesh(os.path.join(opt.log_dir, "step-{:06d}-cage2.ply".format(t)),
                    #               outputs["new_cage"][0].detach().cpu(), outputs["cage_face"][0].detach().cpu(), binary=True)

                if loss_sum > 100 * running_avg_loss:
                    logger.info(
                        "loss ({}) > 100*running_avg_loss ({}). Skip without update."
                        .format(loss_sum, running_avg_loss))
                    torch.cuda.empty_cache()
                    continue

                loss_sum.backward()

                if opt.alternate_cd:
                    optimize_C = (progress > opt.warmup_epochs) and (
                        t % (opt.c_step + opt.d_step)) > opt.d_step
                    if optimize_C:
                        net.nd_decoder.zero_grad()
                        net.encoder.zero_grad()
                    else:
                        try:
                            net.nc_decoder.zero_grad()
                        except AttributeError:
                            net.template_vertices.grad.zero_()

                # clamp_gradient_norm(net, 1)
                optimizer.step()
                if (t + 1) % 500 == 0:
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="latest")

                t += 1

            if (epoch + 1) % save_interval == 0:
                save_network(net,
                             opt.log_dir,
                             network_label="net",
                             epoch_label=epoch)

            scheduler.step()

    log_file.close()
    save_network(net, opt.log_dir, network_label="net", epoch_label="final")
    test_all(net=net)
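
train() keeps the cage decoder (and optionally the template vertices) out of the optimizer until the warmup epochs have passed, then registers them with a 10x smaller learning rate via add_param_group. A standalone sketch of that pattern with toy modules (not the repository's network):

import torch

encoder = torch.nn.Linear(8, 8)
nd_decoder = torch.nn.Linear(8, 8)
nc_decoder = torch.nn.Linear(8, 8)   # held back until warmup finishes
lr = 1e-3

optimizer = torch.optim.Adam([
    {"params": nd_decoder.parameters()},
    {"params": encoder.parameters()},
], lr=lr)

warmed_up = False
for epoch in range(5):
    if epoch >= 2 and not warmed_up:  # stand-in for progress >= opt.warmup_epochs
        optimizer.add_param_group({"params": nc_decoder.parameters(), "lr": 0.1 * lr})
        warmed_up = True
    # ... forward / backward / optimizer.step() as in train() above ...

print([group["lr"] for group in optimizer.param_groups])  # [0.001, 0.001, 0.0001]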
Code Example #4
def test(net=None, subdir="test"):
    opt.phase = "test"
    if isinstance(opt.target_model, str):
        opt.target_model = [opt.target_model]

    if net is None:
        states = torch.load(opt.ckpt)
        if "states" in states:
            states = states["states"]
        if opt.template:
            cage_shape, cage_face = read_trimesh(opt.template)
            cage_shape = torch.from_numpy(
                cage_shape[:, :3]).unsqueeze(0).float()
            cage_face = torch.from_numpy(cage_face).unsqueeze(0).long()
            states["template_vertices"] = cage_shape.transpose(1, 2)
            states["template_faces"] = cage_face

        if opt.source_model:
            source_shape, source_face = read_trimesh(opt.source_model)
            source_shape = torch.from_numpy(
                source_shape[:, :3]).unsqueeze(0).float()
            source_face = torch.from_numpy(source_face).unsqueeze(0).long()
            states["source_vertices"] = source_shape.transpose(1, 2)
            states["source_faces"] = source_shape

        net = networks.FixedSourceDeformer(
            opt,
            3,
            opt.num_point,
            bottleneck_size=opt.bottleneck_size,
            template_vertices=states["template_vertices"],
            template_faces=states["template_faces"],
            source_vertices=states["source_vertices"],
            source_faces=states["source_faces"]).cuda()

        load_network(net, states)
        net = net.cuda()
        net.eval()
    else:
        net.eval()

    print(net)

    test_output_dir = os.path.join(opt.log_dir, subdir)
    os.makedirs(test_output_dir, exist_ok=True)
    with torch.no_grad():
        for target_model in opt.target_model:
            assert (os.path.isfile(target_model))
            target_face = None
            target_shape, target_face = read_trimesh(target_model)
            # target_shape = read_ply(target_model)[:,:3]
            # target_shape, _, scale = normalize_to_box(target_shape)
            # normalize acording to height y axis
            # target_shape = target_shape/2*1.7
            target_shape = torch.from_numpy(
                target_shape[:, :3]).cuda().float().unsqueeze(0)
            if target_face is None:
                target_face = net.source_faces
            else:
                target_face = torch.from_numpy(
                    target_face).cuda().long().unsqueeze(0)
            t_filename = os.path.splitext(os.path.basename(target_model))[0]

            source_mesh = net.source_vertices.transpose(1, 2).detach()
            source_face = net.source_faces.detach()

            # furthest sampling
            target_shape_sampled = furthest_point_sample(
                target_shape, net.source_vertices.shape[2], NCHW=False)[1]
            # target_shape_sampled = (target_shape[:, np.random.permutation(target_shape.shape[1]), :]).contiguous()
            outputs = net(target_shape_sampled.transpose(1, 2),
                          None,
                          cage_only=True)
            # deformed = outputs["deformed"]

            deformed = deform_with_MVC(
                outputs["cage"], outputs["new_cage"],
                outputs["cage_face"].expand(outputs["cage"].shape[0], -1, -1),
                source_mesh)

            b = 0

            save_ply(
                target_shape_sampled[b].cpu().numpy(),
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sb.pts".format(t_filename)))
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sa.ply".format(t_filename)),
                source_mesh[0].detach().cpu(), source_face[0].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sb.ply".format(t_filename)),
                target_shape[b].detach().cpu(), target_face[b].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-Sab.ply".format(t_filename)),
                deformed[b].detach().cpu(), source_face[b].detach().cpu())

            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-cage1.ply".format(t_filename)),
                outputs["cage"][b].detach().cpu(),
                outputs["cage_face"][b].detach().cpu())
            pymesh.save_mesh_raw(
                os.path.join(opt.log_dir, subdir,
                             "template-{}-cage2.ply".format(t_filename)),
                outputs["new_cage"][b].detach().cpu(),
                outputs["cage_face"][b].detach().cpu())

    PairedSurreal.render_result(test_output_dir)
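
test() resamples the dense target down to the number of source points with furthest_point_sample before running the network. For orientation, here is a plain CPU sketch of greedy farthest point sampling; it is a reference re-implementation for illustration, not the repository's batched CUDA op (whose NCHW flag and return values differ):

import torch

def farthest_point_sample_simple(points, k):
    # points: (N, 3) point cloud; returns indices of k greedily chosen samples
    N = points.shape[0]
    idx = torch.zeros(k, dtype=torch.long)
    dist = torch.full((N,), float("inf"))
    idx[0] = torch.randint(N, (1,))
    for i in range(1, k):
        # squared distance of every point to the most recently selected sample
        d = torch.sum((points - points[idx[i - 1]]) ** 2, dim=-1)
        dist = torch.minimum(dist, d)   # distance to the selected set so far
        idx[i] = torch.argmax(dist)     # pick the point farthest from the set
    return idx

pts = torch.rand(2048, 3)
sample_idx = farthest_point_sample_simple(pts, 64)
print(pts[sample_idx].shape)            # torch.Size([64, 3])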
Code Example #5
def test_all(net=None, subdir="test"):
    opt.phase = "test"
    dataset = build_dataset(opt)

    if net is None:
        source_shape = dataset.mesh_vertex.unsqueeze(0).to(dtype=torch.float)
        source_face = dataset.mesh_face.unsqueeze(0)
        cage_shape = dataset.cage_vertex.unsqueeze(0).to(dtype=torch.float)
        cage_face = dataset.cage_face.unsqueeze(0)
        net = networks.FixedSourceDeformer(
            opt,
            3,
            opt.num_point,
            bottleneck_size=opt.bottleneck_size,
            template_vertices=cage_shape.transpose(1, 2),
            template_faces=cage_face,
            source_vertices=source_shape.transpose(1, 2),
            source_faces=source_face).cuda()

        load_network(net, opt.ckpt)
        net.eval()
    else:
        net.eval()

    print(net)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=3,
        worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] +
                                                 id))

    chamfer_distance = losses.LabeledChamferDistance(beta=0, gamma=1)
    mse_distance = torch.nn.MSELoss()
    avg_CD = 0
    avg_EMD = 0
    test_output_dir = os.path.join(opt.log_dir, subdir)
    os.makedirs(test_output_dir, exist_ok=True)
    with open(os.path.join(test_output_dir, "eval.txt"), "w") as f:
        with torch.no_grad():
            source_mesh = net.source_vertices.transpose(1, 2).detach()
            source_face = net.source_faces.detach()
            for i, data in enumerate(dataloader):
                data = dataset.uncollate(data)

                target_shape, target_filename = data["target_shape"], data[
                    "target_file"]

                sample_idx = None
                if "sample_idx" in data:
                    sample_idx = data["sample_idx"]
                outputs = net(target_shape.transpose(1, 2), sample_idx)
                deformed = outputs["deformed"]

                deformed = deform_with_MVC(
                    outputs["cage"], outputs["new_cage"],
                    outputs["cage_face"].expand(outputs["cage"].shape[0], -1, -1),
                    source_mesh)

                for b in range(outputs["deformed"].shape[0]):
                    t_filename = os.path.splitext(target_filename[b])[0]
                    target_shape_np = target_shape.detach().cpu()[b].numpy()
                    if data["target_face"] is not None and data[
                            "target_mesh"] is not None:
                        pymesh.save_mesh_raw(
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sa.ply".format(t_filename)),
                            source_mesh[0].detach().cpu(),
                            source_face[0].detach().cpu())
                        pymesh.save_mesh_raw(
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sb.ply".format(t_filename)),
                            data["target_mesh"][b].detach().cpu(),
                            data["target_face"][b].detach().cpu())
                        pymesh.save_mesh_raw(
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sab.ply".format(t_filename)),
                            deformed[b].detach().cpu(),
                            source_face[b].detach().cpu())
                    else:
                        save_ply(
                            source_mesh[0].detach().cpu(),
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sa.ply".format(t_filename)))
                        save_ply(
                            target_shape[b].detach().cpu(),
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sb.ply".format(t_filename)),
                            normals=data["target_normals"][b].detach().cpu())
                        save_ply(
                            deformed[b].detach().cpu(),
                            os.path.join(
                                opt.log_dir, subdir,
                                "template-{}-Sab.ply".format(t_filename)),
                            normals=data["target_normals"][b].detach().cpu())

                    pymesh.save_mesh_raw(
                        os.path.join(
                            opt.log_dir, subdir,
                            "template-{}-cage1.ply".format(t_filename)),
                        outputs["cage"][b].detach().cpu(),
                        outputs["cage_face"][b].detach().cpu())
                    pymesh.save_mesh_raw(
                        os.path.join(
                            opt.log_dir, subdir,
                            "template-{}-cage2.ply".format(t_filename)),
                        outputs["new_cage"][b].detach().cpu(),
                        outputs["cage_face"][b].detach().cpu())

                    log_str = "{}/{} {}".format(i, len(dataloader), t_filename)
                    print(log_str)
                    f.write(log_str + "\n")

    dataset.render_result(test_output_dir)
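
test_all() instantiates losses.LabeledChamferDistance(beta=0, gamma=1) for evaluation; its internals are not shown in this snippet. For orientation, a plain symmetric Chamfer distance between two point clouds can be sketched as follows (a generic O(N*M) reference, not the repository's implementation):

import torch

def chamfer_distance_naive(a, b):
    # a: (B, N, 3), b: (B, M, 3); returns one scalar per batch element
    d2 = ((a.unsqueeze(2) - b.unsqueeze(1)) ** 2).sum(-1)   # (B, N, M) squared distances
    return d2.min(dim=2).values.mean(dim=1) + d2.min(dim=1).values.mean(dim=1)

a = torch.rand(1, 512, 3)
b = a + 0.01 * torch.randn_like(a)
print(chamfer_distance_naive(a, b))   # small for nearly identical clouds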
Code Example #6
def test(net=None, save_subdir="test"):
    opt.phase = "test"
    dataset = build_dataset(opt)

    if opt.dim == 3:
        init_cage_V, init_cage_Fs = loadInitCage([opt.template])
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
    else:
        init_cage_V = generatePolygon(0, 0, 1.5, 0, 0, 0, opt.cage_deg)
        init_cage_V = torch.tensor([(x, y) for x, y in init_cage_V],
                                   dtype=torch.float).unsqueeze(0)
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        init_cage_Fs = [
            torch.arange(opt.cage_deg, dtype=torch.int64).view(1, 1,
                                                               -1).cuda()
        ]

    if net is None:
        # network
        net = networks.NetworkFull(
            opt,
            dim=opt.dim,
            bottleneck_size=opt.bottleneck_size,
            template_vertices=cage_V_t,
            template_faces=init_cage_Fs[-1],
        ).cuda()
        net.eval()
        load_network(net, opt.ckpt)
    else:
        net.eval()

    print(net)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=tolerating_collate,
        num_workers=0,
        worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] +
                                                 id))

    test_output_dir = os.path.join(opt.log_dir, save_subdir)
    os.makedirs(test_output_dir, exist_ok=True)
    with open(os.path.join(test_output_dir, "eval.txt"), "w") as f:
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                data = dataset.uncollate(data)

                ############# blending ############
                # sample 4 different alpha
                if opt.blend_style:
                    num_alpha = 4
                    blend_alpha = torch.linspace(
                        0, 1, steps=num_alpha,
                        dtype=torch.float32).cuda().reshape(num_alpha, 1)
                    data["source_shape"] = data["source_shape"].expand(
                        num_alpha, -1, -1).contiguous()
                    data["target_shape"] = data["target_shape"].expand(
                        num_alpha, -1, -1).contiguous()
                else:
                    blend_alpha = 1.0

                data["alpha"] = blend_alpha

                ###################################
                source_shape_t = data["source_shape"].transpose(
                    1, 2).contiguous().detach()
                target_shape_t = data["target_shape"].transpose(
                    1, 2).contiguous().detach()

                outputs = net(source_shape_t, target_shape_t, blend_alpha)
                deformed = outputs["deformed"]

                ####################### evaluation ########################
                s_filename = os.path.splitext(data["source_file"][0])[0]
                t_filename = os.path.splitext(data["target_file"][0])[0]

                log_str = "{}/{} {}-{} ".format(i, len(dataloader), s_filename,
                                                t_filename)
                print(log_str)
                f.write(log_str + "\n")

                ###################### outputs ############################
                for b in range(deformed.shape[0]):
                    if "source_mesh" in data and data[
                            "source_mesh"] is not None:
                        if isinstance(data["source_mesh"][0], str):
                            source_mesh = om.read_polymesh(
                                data["source_mesh"][0]).points().copy()
                            source_mesh = dataset.normalize(
                                source_mesh, opt.isV2)
                            source_mesh = torch.from_numpy(
                                source_mesh.astype(
                                    np.float32)).unsqueeze(0).cuda()
                            deformed = deform_with_MVC(
                                outputs["cage"][b:b + 1],
                                outputs["new_cage"][b:b + 1],
                                outputs["cage_face"], source_mesh)
                        else:
                            deformed = deform_with_MVC(
                                outputs["cage"][b:b + 1],
                                outputs["new_cage"][b:b + 1],
                                outputs["cage_face"], data["source_mesh"])

                    deformed[b] = center_bounding_box(deformed[b])[0]
                    if data["source_face"] is not None and data[
                            "source_mesh"] is not None:
                        source_mesh = data["source_mesh"][0].detach().cpu()
                        source_mesh = center_bounding_box(source_mesh)[0]
                        source_face = data["source_face"][0].detach().cpu()
                        tosave = pymesh.form_mesh(vertices=source_mesh,
                                                  faces=source_face)
                        pymesh.save_mesh(os.path.join(
                            opt.log_dir, save_subdir,
                            "{}-{}-Sa.obj".format(s_filename, t_filename)),
                                         tosave,
                                         use_float=True)
                        tosave = pymesh.form_mesh(
                            vertices=deformed[0].detach().cpu(),
                            faces=source_face)
                        pymesh.save_mesh(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sab-{}.obj".format(
                                    s_filename, t_filename, b)),
                            tosave,
                            use_float=True,
                        )
                    elif data["source_face"] is None and isinstance(
                            data["source_mesh"][0], str):
                        orig_file_path = data["source_mesh"][0]
                        mesh = om.read_polymesh(orig_file_path)
                        points_arr = mesh.points()
                        points_arr[:] = source_mesh[0].detach().cpu().numpy()
                        om.write_mesh(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sa.obj".format(s_filename, t_filename)),
                            mesh)
                        points_arr[:] = deformed[0].detach().cpu().numpy()
                        om.write_mesh(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sab-{}.obj".format(
                                    s_filename, t_filename, b)), mesh)
                    else:
                        # save to "pts" for rendering
                        save_pts(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sa.pts".format(s_filename, t_filename)),
                            data["source_shape"][b].detach().cpu())
                        save_pts(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sab-{}.pts".format(
                                    s_filename, t_filename, b)),
                            deformed[0].detach().cpu())

                    if data["target_face"] is not None and data[
                            "target_mesh"] is not None:
                        data["target_mesh"][0] = center_bounding_box(
                            data["target_mesh"][0])[0]
                        tosave = pymesh.form_mesh(
                            vertices=data["target_mesh"][0].detach().cpu(),
                            faces=data["target_face"][0].detach().cpu())
                        pymesh.save_mesh(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sb.obj".format(s_filename, t_filename)),
                            tosave,
                            use_float=True,
                        )
                    elif data["target_face"] is None and isinstance(
                            data["target_mesh"][0], str):
                        orig_file_path = data["target_mesh"][0]
                        mesh = om.read_polymesh(orig_file_path)
                        points_arr = mesh.points()
                        points_arr[:] = dataset.normalize(
                            points_arr.copy(), opt.isV2)
                        om.write_mesh(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sb.obj".format(s_filename, t_filename)),
                            mesh)
                    else:
                        save_pts(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sb.pts".format(s_filename, t_filename)),
                            data["target_shape"][0].detach().cpu())

                    outputs["cage"][b] = center_bounding_box(
                        outputs["cage"][b])[0]
                    outputs["new_cage"][b] = center_bounding_box(
                        outputs["new_cage"][b])[0]
                    pymesh.save_mesh_raw(
                        os.path.join(
                            opt.log_dir, save_subdir,
                            "{}-{}-cage1-{}.ply".format(
                                s_filename, t_filename, b)),
                        outputs["cage"][b].detach().cpu(),
                        outputs["cage_face"][0].detach().cpu(),
                        binary=True)
                    pymesh.save_mesh_raw(
                        os.path.join(
                            opt.log_dir, save_subdir,
                            "{}-{}-cage2-{}.ply".format(
                                s_filename, t_filename, b)),
                        outputs["new_cage"][b].detach().cpu(),
                        outputs["cage_face"][0].detach().cpu(),
                        binary=True)

    dataset.render_result(test_output_dir)
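
In the 2D branch the initial cage is a closed polygon of opt.cage_deg vertices around the origin, and init_cage_Fs stores the vertex loop as indices. A minimal stand-in for that setup (a regular polygon of radius 1.5; this is an assumption about what generatePolygon(0, 0, 1.5, 0, 0, 0, opt.cage_deg) produces with zero irregularity, not its actual implementation):

import math
import torch

def regular_polygon_cage(num_vertices, radius=1.5):
    # returns (1, num_vertices, 2) vertices and a (1, 1, num_vertices) index loop
    angles = torch.arange(num_vertices, dtype=torch.float) * (2 * math.pi / num_vertices)
    verts = torch.stack([radius * torch.cos(angles),
                         radius * torch.sin(angles)], dim=-1).unsqueeze(0)
    faces = torch.arange(num_vertices, dtype=torch.int64).view(1, 1, -1)
    return verts, faces

cage_V, cage_Fs = regular_polygon_cage(8)
cage_V_t = cage_V.transpose(1, 2)   # (1, 2, 8), the layout the network above expects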
Code Example #7
File: optimize_cage.py  Project: star-cold/deep_cage
def test_all(opt, new_cage_shape):
    opt.phase = "test"
    opt.target_model = None
    print(opt.model)

    if opt.is_poly:
        source_mesh = om.read_polymesh(opt.model)
    else:
        source_mesh = om.read_trimesh(opt.model)

    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False,
                                             collate_fn=tolerating_collate,
                                             num_workers=0, worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] + id))

    states = torch.load(opt.ckpt)
    if "states" in states:
        states = states["states"]
    # states["template_vertices"] = new_cage_shape.transpose(1, 2)
    # states["source_vertices"] = new_source.transpose(1,2)
    # states["source_faces"] = new_source_face
    # new_source_face = states["source_faces"]

    om.write_mesh(os.path.join(opt.log_dir, opt.subdir,
                               "template-Sa.ply"), source_mesh)

    net = networks.FixedSourceDeformer(opt, 3, opt.num_point, bottleneck_size=opt.bottleneck_size,
                                       template_vertices=states["template_vertices"], template_faces=states["template_faces"].cuda(),
                                       source_vertices=states["source_vertices"], source_faces=states["source_faces"]).cuda()
    load_network(net, states)

    source_points = torch.from_numpy(
        source_mesh.points().copy()).float().cuda().unsqueeze(0)
    with torch.no_grad():
        # source_face = net.source_faces.detach()
        for i, data in enumerate(dataloader):
            data = dataset.uncollate(data)

            target_shape, target_filename = data["target_shape"], data["target_file"]
            logger.info("", data["target_file"][0])

            sample_idx = None
            if "sample_idx" in data:
                sample_idx = data["sample_idx"]

            outputs = net(target_shape.transpose(1, 2), sample_idx, cage_only=True)
            if opt.d_residual:
                cage_offset = outputs["new_cage"]-outputs["cage"]
                outputs["cage"] = new_cage_shape
                outputs["new_cage"] = new_cage_shape+cage_offset

            deformed = deform_with_MVC(outputs["cage"], outputs["new_cage"], outputs["cage_face"].expand(
                outputs["cage"].shape[0], -1, -1), source_points)

            for b in range(deformed.shape[0]):
                t_filename = os.path.splitext(target_filename[b])[0]
                source_mesh_arr = source_mesh.points()
                source_mesh_arr[:] = deformed[0].cpu().detach().numpy()
                om.write_mesh(os.path.join(
                    opt.log_dir, opt.subdir, "template-{}-Sab.obj".format(t_filename)), source_mesh)
                # if data["target_face"] is not None and data["target_mesh"] is not None:
                # pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sa.ply".format(t_filename)),
                #             source_mesh[0].detach().cpu(), source_face[b].detach().cpu())
                pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sb.ply".format(t_filename)),
                                     data["target_mesh"][b].detach().cpu(), data["target_face"][b].detach().cpu())
                # pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sab.ply".format(t_filename)),
                #             deformed[b].detach().cpu(), source_face[b].detach().cpu())

                # else:
                #     save_ply(source_mesh[0].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sa.ply".format(t_filename)))
                #     save_ply(target_shape[b].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sb.ply".format(t_filename)),
                #                 normals=data["target_normals"][b].detach().cpu())
                #     save_ply(deformed[b].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sab.ply".format(t_filename)),
                #                 normals=data["target_normals"][b].detach().cpu())

                pymesh.save_mesh_raw(
                    os.path.join(opt.log_dir, opt.subdir, "template-{}-cage1.ply".format(t_filename)),
                    outputs["cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu())
                pymesh.save_mesh_raw(
                    os.path.join(opt.log_dir, opt.subdir, "template-{}-cage2.ply".format(t_filename)),
                    outputs["new_cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu())

            # if opt.opt_lap and deformed.shape[1] == source_mesh.shape[1]:
            #     deformed = optimize_lap(opt, source_mesh, deformed, source_face)
            #     for b in range(deformed.shape[0]):
            #         pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sab-optlap.ply".format(t_filename)),
            #                                 deformed[b].detach().cpu(), source_face[b].detach().cpu())

            if i % 20 == 0:
                logger.success("[{}/{}] Done".format(i, len(dataloader)))

    dataset.render_result(os.path.join(opt.log_dir, opt.subdir))
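
With opt.d_residual the network's cage prediction is used only as an offset: the optimized cage new_cage_shape replaces outputs["cage"], and the predicted displacement is re-applied on top of it. Isolated from the rest of the loop, the update is just the following (toy tensors; the (B,3,P) layout is chosen only for illustration):

import torch

predicted_cage = torch.rand(1, 3, 30)
predicted_new_cage = predicted_cage + 0.05 * torch.randn_like(predicted_cage)
new_cage_shape = torch.rand(1, 3, 30)              # cage coming from the optimization step

cage_offset = predicted_new_cage - predicted_cage  # displacement predicted by the network
cage = new_cage_shape                              # swap in the optimized cage ...
new_cage = new_cage_shape + cage_offset            # ... and reapply the same offset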