示例#1
0
def test(net=None, save_subdir="test"):
    opt.phase = "test"
    dataset = build_dataset(opt)

    if opt.dim == 3:
        init_cage_V, init_cage_Fs = loadInitCage([opt.template])
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
    else:
        init_cage_V = generatePolygon(0, 0, 1.5, 0, 0, 0, opt.cage_deg)
        init_cage_V = torch.tensor([(x, y) for x, y in init_cage_V],
                                   dtype=torch.float).unsqueeze(0)
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        init_cage_Fs = [
            torch.arange(opt.cage_deg, dtype=torch.int64).view(1, 1,
                                                               -1).cuda()
        ]

    if net is None:
        # network
        net = networks.NetworkFull(
            opt,
            dim=opt.dim,
            bottleneck_size=opt.bottleneck_size,
            template_vertices=cage_V_t,
            template_faces=init_cage_Fs[-1],
        ).cuda()
        net.eval()
        load_network(net, opt.ckpt)
    else:
        net.eval()

    print(net)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=tolerating_collate,
        num_workers=0,
        worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] +
                                                 id))

    test_output_dir = os.path.join(opt.log_dir, save_subdir)
    os.makedirs(test_output_dir, exist_ok=True)
    with open(os.path.join(test_output_dir, "eval.txt"), "w") as f:
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                data = dataset.uncollate(data)

                ############# blending ############
                # sample 4 different alpha
                if opt.blend_style:
                    num_alpha = 4
                    blend_alpha = torch.linspace(
                        0, 1, steps=num_alpha,
                        dtype=torch.float32).cuda().reshape(num_alpha, 1)
                    data["source_shape"] = data["source_shape"].expand(
                        num_alpha, -1, -1).contiguous()
                    data["target_shape"] = data["target_shape"].expand(
                        num_alpha, -1, -1).contiguous()
                else:
                    blend_alpha = 1.0

                data["alpha"] = blend_alpha

                ###################################
                source_shape_t = data["source_shape"].transpose(
                    1, 2).contiguous().detach()
                target_shape_t = data["target_shape"].transpose(
                    1, 2).contiguous().detach()

                outputs = net(source_shape_t, target_shape_t, blend_alpha)
                deformed = outputs["deformed"]

                ####################### evaluation ########################
                s_filename = os.path.splitext(data["source_file"][0])[0]
                t_filename = os.path.splitext(data["target_file"][0])[0]

                log_str = "{}/{} {}-{} ".format(i, len(dataloader), s_filename,
                                                t_filename)
                print(log_str)
                f.write(log_str + "\n")

                ###################### outputs ############################
                for b in range(deformed.shape[0]):
                    if "source_mesh" in data and data[
                            "source_mesh"] is not None:
                        if isinstance(data["source_mesh"][0], str):
                            source_mesh = om.read_polymesh(
                                data["source_mesh"][0]).points().copy()
                            source_mesh = dataset.normalize(
                                source_mesh, opt.isV2)
                            source_mesh = torch.from_numpy(
                                source_mesh.astype(
                                    np.float32)).unsqueeze(0).cuda()
                            deformed = deform_with_MVC(
                                outputs["cage"][b:b + 1],
                                outputs["new_cage"][b:b + 1],
                                outputs["cage_face"], source_mesh)
                        else:
                            deformed = deform_with_MVC(
                                outputs["cage"][b:b + 1],
                                outputs["new_cage"][b:b + 1],
                                outputs["cage_face"], data["source_mesh"])

                    deformed[b] = center_bounding_box(deformed[b])[0]
                    if data["source_face"] is not None and data[
                            "source_mesh"] is not None:
                        source_mesh = data["source_mesh"][0].detach().cpu()
                        source_mesh = center_bounding_box(source_mesh)[0]
                        source_face = data["source_face"][0].detach().cpu()
                        tosave = pymesh.form_mesh(vertices=source_mesh,
                                                  faces=source_face)
                        pymesh.save_mesh(os.path.join(
                            opt.log_dir, save_subdir,
                            "{}-{}-Sa.obj".format(s_filename, t_filename)),
                                         tosave,
                                         use_float=True)
                        tosave = pymesh.form_mesh(
                            vertices=deformed[0].detach().cpu(),
                            faces=source_face)
                        pymesh.save_mesh(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sab-{}.obj".format(
                                    s_filename, t_filename, b)),
                            tosave,
                            use_float=True,
                        )
                    elif data["source_face"] is None and isinstance(
                            data["source_mesh"][0], str):
                        orig_file_path = data["source_mesh"][0]
                        mesh = om.read_polymesh(orig_file_path)
                        points_arr = mesh.points()
                        points_arr[:] = source_mesh[0].detach().cpu().numpy()
                        om.write_mesh(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sa.obj".format(s_filename, t_filename)),
                            mesh)
                        points_arr[:] = deformed[0].detach().cpu().numpy()
                        om.write_mesh(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sab-{}.obj".format(
                                    s_filename, t_filename, b)), mesh)
                    else:
                        # save to "pts" for rendering
                        save_pts(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sa.pts".format(s_filename, t_filename)),
                            data["source_shape"][b].detach().cpu())
                        save_pts(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sab-{}.pts".format(
                                    s_filename, t_filename, b)),
                            deformed[0].detach().cpu())

                    if data["target_face"] is not None and data[
                            "target_mesh"] is not None:
                        data["target_mesh"][0] = center_bounding_box(
                            data["target_mesh"][0])[0]
                        tosave = pymesh.form_mesh(
                            vertices=data["target_mesh"][0].detach().cpu(),
                            faces=data["target_face"][0].detach().cpu())
                        pymesh.save_mesh(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sb.obj".format(s_filename, t_filename)),
                            tosave,
                            use_float=True,
                        )
                    elif data["target_face"] is None and isinstance(
                            data["target_mesh"][0], str):
                        orig_file_path = data["target_mesh"][0]
                        mesh = om.read_polymesh(orig_file_path)
                        points_arr = mesh.points()
                        points_arr[:] = dataset.normalize(
                            points_arr.copy(), opt.isV2)
                        om.write_mesh(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sb.obj".format(s_filename, t_filename)),
                            mesh)
                    else:
                        save_pts(
                            os.path.join(
                                opt.log_dir, save_subdir,
                                "{}-{}-Sb.pts".format(s_filename, t_filename)),
                            data["target_shape"][0].detach().cpu())

                    outputs["cage"][b] = center_bounding_box(
                        outputs["cage"][b])[0]
                    outputs["new_cage"][b] = center_bounding_box(
                        outputs["new_cage"][b])[0]
                    pymesh.save_mesh_raw(
                        os.path.join(
                            opt.log_dir, save_subdir,
                            "{}-{}-cage1-{}.ply".format(
                                s_filename, t_filename, b)),
                        outputs["cage"][b].detach().cpu(),
                        outputs["cage_face"][0].detach().cpu(),
                        binary=True)
                    pymesh.save_mesh_raw(
                        os.path.join(
                            opt.log_dir, save_subdir,
                            "{}-{}-cage2-{}.ply".format(
                                s_filename, t_filename, b)),
                        outputs["new_cage"][b].detach().cpu(),
                        outputs["cage_face"][0].detach().cpu(),
                        binary=True)

    dataset.render_result(test_output_dir)
示例#2
0
def train():
    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=tolerating_collate,
        num_workers=2,
        worker_init_fn=lambda id: np.random.seed(np.random.get_state()[1][0] +
                                                 id))

    if opt.dim == 3:
        # cage (1,N,3)
        init_cage_V, init_cage_Fs = loadInitCage([opt.template])
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        cage_edge_points_list = []
        cage_edges_list = []
        for F in init_cage_Fs:
            mesh = Mesh(vertices=init_cage_V[0], faces=F[0])
            build_gemm(mesh, F[0])
            cage_edge_points = torch.from_numpy(get_edge_points(mesh)).cuda()
            cage_edge_points_list.append(cage_edge_points)
            cage_edges_list = [edge_vertex_indices(F[0])]
    else:
        init_cage_V = generatePolygon(0, 0, 1.5, 0, 0, 0, opt.cage_deg)
        init_cage_V = torch.tensor([(x, y) for x, y in init_cage_V],
                                   dtype=torch.float).unsqueeze(0)
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        init_cage_Fs = [
            torch.arange(opt.cage_deg, dtype=torch.int64).view(1, 1,
                                                               -1).cuda()
        ]

    # network
    net = networks.NetworkFull(
        opt,
        dim=opt.dim,
        bottleneck_size=opt.bottleneck_size,
        template_vertices=cage_V_t,
        template_faces=init_cage_Fs[-1],
    ).cuda()

    net.apply(weights_init)
    if opt.ckpt:
        load_network(net, opt.ckpt)

    all_losses = losses.AllLosses(opt)
    # optimizer
    optimizer = torch.optim.Adam([{
        "params": net.encoder.parameters()
    }, {
        "params": net.nd_decoder.parameters()
    }, {
        "params": net.merger.parameters()
    }],
                                 lr=opt.lr)

    if opt.full_net:
        optimizer.add_param_group({
            'params': net.nc_decoder.parameters(),
            'lr': 0.1 * opt.lr
        })
    if opt.optimize_template:
        optimizer.add_param_group({
            'params': net.template_vertices,
            'lr': opt.lr
        })

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                int(opt.nepochs * 0.4),
                                                gamma=0.1,
                                                last_epoch=-1)

    # train
    net.train()
    start_epoch = 0
    t = 0

    steps_C = 20
    steps_D = 20

    # train
    os.makedirs(opt.log_dir, exist_ok=True)
    shutil.copy2(__file__, opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "networks.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "losses.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "datasets.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "common.py"),
                 opt.log_dir)
    shutil.copy2(os.path.join(os.path.dirname(__file__), "option.py"),
                 opt.log_dir)
    print(net)

    log_file = open(os.path.join(opt.log_dir, "training_log.txt"), "a")
    log_file.write(str(net) + "\n")

    log_interval = max(len(dataloader) // 5, 50)
    save_interval = max(opt.nepochs // 10, 1)
    running_avg_loss = -1

    with torch.autograd.detect_anomaly():
        if opt.epoch:
            start_epoch = opt.epoch % opt.nepochs
            t += start_epoch * len(dataloader)

        for epoch in range(start_epoch, opt.nepochs):
            for t_epoch, data in enumerate(dataloader):
                warming_up = epoch < opt.warmup_epochs
                progress = t_epoch / len(dataloader) + epoch
                optimize_C = (t % (steps_C + steps_D)) > steps_D

                ############# get data ###########
                data = dataset.uncollate(data)
                data = crisscross_input(data)
                if opt.dim == 3:
                    data["cage_edge_points"] = cage_edge_points_list[-1]
                    data["cage_edges"] = cage_edges_list[-1]
                source_shape, target_shape = data["source_shape"], data[
                    "target_shape"]

                ############# blending ############
                if opt.blend_style:
                    blend_alpha = torch.rand(
                        (source_shape.shape[0], 1),
                        dtype=torch.float32).to(device=source_shape.device)
                else:
                    blend_alpha = 1.0
                data["alpha"] = blend_alpha

                ############# run network ###########
                optimizer.zero_grad()
                # optimizer_C.zero_grad()
                # optimizer_D.zero_grad()
                source_shape_t = source_shape.transpose(1, 2)
                target_shape_t = target_shape.transpose(1, 2)
                outputs = net(source_shape_t, target_shape_t, data["alpha"])

                ############# get losses ###########
                current_loss = all_losses(data, outputs, progress)
                loss_sum = torch.sum(
                    torch.stack([v for v in current_loss.values()], dim=0))
                if running_avg_loss < 0:
                    running_avg_loss = loss_sum
                else:
                    running_avg_loss = running_avg_loss + (
                        loss_sum.item() - running_avg_loss) / (t + 1)

                if (t % log_interval
                        == 0) or (loss_sum > 5 * running_avg_loss):
                    log_str = "warming up {} e {:03d} t {:05d}: {}".format(
                        warming_up, epoch, t, ", ".join([
                            "{} {:.3g}".format(k,
                                               v.mean().item())
                            for k, v in current_loss.items()
                        ]))
                    print(log_str)
                    log_file.write(log_str + "\n")
                    log_outputs(opt, t, outputs, data)

                if loss_sum > 100 * running_avg_loss:
                    logger.info(
                        "loss ({}) > 5*running_average_loss ({}). Skip without update."
                        .format(loss_sum, 5 * running_avg_loss))
                    torch.cuda.empty_cache()
                    continue

                loss_sum.backward()
                if epoch < opt.warmup_epochs:
                    try:
                        net.nc_decoder.zero_grad()
                        net.encoder.zero_grad()
                    except AttributeError:
                        net.template_vertices.grad.zero_()

                if opt.alternate_cd:
                    optimize_C = (epoch > opt.warmup_epochs) and (
                        epoch % (opt.c_epoch + opt.d_epoch)) > opt.d_epoch
                    if optimize_C:
                        net.nd_decoder.zero_grad()
                    else:
                        try:
                            net.encoder.zero_grad()
                            net.nc_decoder.zero_grad()
                        except AttributeError:
                            net.template_vertices.grad.zero_()

                clamp_gradient(net, 0.1)
                optimizer.step()

                if (t + 1) % 500 == 0:
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="latest")

                t += 1

            if (epoch + 1) % save_interval == 0:
                save_network(net,
                             opt.log_dir,
                             network_label="net",
                             epoch_label=epoch)

            scheduler.step()
            if opt.eval:
                try:
                    test(net=net, save_subdir="epoch_{}".format(epoch))
                except Exception as e:
                    traceback.print_exc(file=sys.stdout)
                    logger.warn("Failed to run test", str(e))

    log_file.close()
    save_network(net, opt.log_dir, network_label="net", epoch_label="final")
    test(net=net)