コード例 #1
0
ファイル: losses.py プロジェクト: star-cold/deep_cage
    def forward(self, ref_points, points, *args, **kwargs):
        """Compare per-point neighbourhood "condition" values of two point sets.

        Neighbourhoods are defined on ``ref_points`` only: the ``self.nn_size``
        nearest neighbours of every reference point (``faiss_knn``), restricted
        to those whose distance is below ``self.ball_size2``. The same neighbour
        indices are used to gather neighbourhoods from ``points``. Each
        neighbourhood is centred on the mean of its in-ball points and
        decomposed with ``batch_svd``; its value is S[0] / (S[0] + S[-1]) of the
        singular values.

        Args:
            ref_points: (B, N, C) reference points.
            points: points gathered with the reference neighbour indices;
                presumably (B, N, C) in correspondence with ``ref_points`` --
                TODO confirm.
            *args, **kwargs: ignored.
        Returns:
            ``self.metric(cond, ref_cond)``, where both condition maps are (B, N).
        """
        B, N, C = ref_points.shape

        def _cond(grouped, nball):
            # centre each (masked) neighbourhood on the mean of its in-ball points
            center = torch.sum(grouped, dim=2, keepdim=True) / nball.unsqueeze(-1)
            # NOTE(review): masked-out entries are 0 before centring, so they
            # still enter the SVD as -center; confirm this is intended.
            centered = (grouped - center).view(-1, self.nn_size, C).contiguous()
            # S (BN, k)
            _, S, _ = batch_svd(centered)
            return (S[:, 0] / (S[:, -1] + S[:, 0])).view(B, N).contiguous()

        # TODO replace with ball query
        # (B,P,K,3), (B,P,K), (B,P,K)
        ref_grouped_points, ref_group_idx, ref_group_dist = faiss_knn(
            self.nn_size, ref_points, ref_points, NCHW=False)
        # in-ball mask; ball_size2 is presumably the squared radius (TODO confirm
        # that faiss_knn returns squared distances)
        mask = (ref_group_dist < self.ball_size2)
        outside = ~mask.unsqueeze(-1)
        ref_grouped_points = ref_grouped_points.masked_fill(outside, 0.0)
        # number of points inside the ball (B,P,1)
        nball = torch.sum(mask.to(torch.float), dim=-1, keepdim=True)
        ref_cond = _cond(ref_grouped_points, nball)

        # gather neighbourhoods of `points` with the reference neighbour indices
        grouped_points = torch.gather(
            points.unsqueeze(1).expand(-1, N, -1, -1), 2,
            ref_group_idx.unsqueeze(-1).expand(-1, -1, -1, C))
        # masked_fill is out-of-place: keep its result so that out-of-ball
        # neighbours are zeroed exactly as on the reference side
        grouped_points = grouped_points.masked_fill(outside, 0.0)
        cond = _cond(grouped_points, nball)

        return self.metric(cond, ref_cond)
コード例 #2
0
ファイル: losses.py プロジェクト: star-cold/deep_cage
    def forward(self,
                cage_v,
                cage_f,
                shape,
                shape_vn,
                epsilon=0.01,
                interpolate=True):
        """Penalize points sampled on the cage surface that lie inside ``shape``.

        Every cage triangle is sampled with the barycentric weights in
        ``self.sample_weights`` (S, 3). For each sample q, the nearest shape
        point p (``faiss_knn``) and its normal n are looked up. With
        d = <q - p - epsilon * n, n>, the per-sample loss is -d when d < 0 and
        0 otherwise.

        Args:
            cage_v: (B, M, D) cage vertices.
            cage_f: (B, F, 3) vertex indices of the cage triangles.
            shape: (B, N, D) shape points.
            shape_vn: (B, N, D) normals of the ``shape`` points.
            epsilon: margin along the normal. Assuming unit normals, samples
                less than ``epsilon`` outside the surface are penalized too.
            interpolate: unused; kept for interface compatibility.
        Returns:
            "mean": mean over all samples; "max": max over the last axis,
            averaged; "sum": sum over the last axis, averaged; "none": the
            unreduced per-sample loss.
        Raises:
            NotImplementedError: for any other ``self.reduction``.
        """
        B, F, _ = cage_f.shape
        D = cage_v.shape[-1]
        self.sample_weights = self.sample_weights.to(device=shape.device)
        # vertices of every cage triangle: (B, F, 1, 3, D)
        face_idx = cage_f.reshape(B, F * 3, 1).expand(-1, -1, D)
        cage_face_vertices = torch.gather(cage_v, 1, face_idx).reshape(B, F, 1, 3, D)
        # barycentric weights (S, 3) -> (1, 1, S, 3, 1)
        sample_weights = self.sample_weights.unsqueeze(0).unsqueeze(
            0).unsqueeze(-1).to(device=cage_v.device)
        # weighted sum of the 3 triangle vertices: (B, F, S, D) -> (B, F*S, D)
        cage_sampled_points = torch.sum(sample_weights * cage_face_vertices,
                                        dim=-2).reshape(B, -1, D)

        # find the closest point on the shape
        nn_point, nn_index, _ = faiss_knn(1,
                                          cage_sampled_points,
                                          shape,
                                          NCHW=False)
        nn_point = nn_point.squeeze(2)
        # normal of that closest point, (B, F*S, D)
        nn_normal = torch.gather(
            shape_vn.unsqueeze(1).expand(-1, nn_index.shape[1], -1, -1), 2,
            nn_index.unsqueeze(-1).expand(-1, -1, -1, shape_vn.shape[-1]))
        nn_normal = nn_normal.squeeze(2)

        # if <(q-p), n> is negative, then this point is inside the shape, gradient is along the normal direction
        dot = dot_product(cage_sampled_points - nn_point - epsilon * nn_normal,
                          nn_normal,
                          dim=-1)
        loss = torch.where(dot < 0, -dot, torch.zeros_like(dot))

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "max":
            return torch.mean(torch.max(loss, dim=-1)[0])
        if self.reduction == "sum":
            # sum over the last axis, then average (Tensor.mean does not take
            # a tensor as its ``dim`` argument)
            return torch.mean(torch.sum(loss, dim=-1))
        if self.reduction == "none":
            return loss
        raise NotImplementedError
コード例 #3
0
ファイル: losses.py プロジェクト: star-cold/deep_cage
    def forward(self,
                cage,
                shape,
                shape_normals,
                epsilon=0.01,
                interpolate=True,
                interpolate_n=10):
        """Penalize polygon cage that is inside the given shape.

        The polygon is closed: each edge runs from vertex i to vertex
        (i + 1) % M and is sampled at ``interpolate_n`` evenly spaced points,
        with both endpoints included. For each sample q, the nearest shape
        point p (``faiss_knn``) and its normal n are looked up. With
        d = <q - p - epsilon * n, n>, the per-sample loss is -d when d < 0 and
        0 otherwise.

        Args:
            cage: (B,M,3) polygon vertices, in order
            shape: (B,N,3)
            shape_normals: (B,N,3)
            epsilon: margin along the normal
            interpolate: unused; kept for interface compatibility
            interpolate_n: samples per edge (10 by default, as before)
        Returns:
            "mean": mean over all samples; "max": max over the last axis,
            averaged; "sum": sum over the last axis, averaged; "none": the
            unreduced per-sample loss.
        Raises:
            NotImplementedError: for any other ``self.reduction``.
        """
        B, M, D = cage.shape
        # successor of every vertex (closes the polygon), same as
        # cage[:, [1, ..., M-1, 0], :]
        cage_next = torch.roll(cage, shifts=-1, dims=1)
        t = torch.linspace(0, 1, interpolate_n).to(device=cage.device)
        t = t.reshape([1, 1, interpolate_n, 1])
        # B,M,K,D: t * v_{i+1} + (1 - t) * v_i via broadcasting
        cage_itp = t * cage_next.unsqueeze(2) + (1 - t) * cage.unsqueeze(2)
        cage_itp = cage_itp.reshape(B, -1, D)

        # find the closest point on the shape
        nn_point, nn_index, _ = faiss_knn(1, cage_itp, shape, NCHW=False)
        nn_point = nn_point.squeeze(2)
        # normal of that closest point
        nn_normal = torch.gather(
            shape_normals.unsqueeze(1).expand(-1, nn_index.shape[1], -1, -1),
            2,
            nn_index.unsqueeze(-1).expand(-1, -1, -1, shape_normals.shape[-1]))
        nn_normal = nn_normal.squeeze(2)

        # if <(q-p), n> is negative, then this point is inside the shape, gradient is along the normal direction
        dot = dot_product(cage_itp - nn_point - epsilon * nn_normal,
                          nn_normal,
                          dim=-1)
        loss = torch.where(dot < 0, -dot, torch.zeros_like(dot))

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "max":
            return torch.mean(torch.max(loss, dim=-1)[0])
        if self.reduction == "sum":
            # sum over the last axis, then average (Tensor.mean does not take
            # a tensor as its ``dim`` argument)
            return torch.mean(torch.sum(loss, dim=-1))
        if self.reduction == "none":
            return loss
        raise NotImplementedError
コード例 #4
0
# Add a leading batch dimension to target_shape in place. target_shape is
# defined before this fragment; presumably an (N, 3) tensor of target mesh
# vertices -- TODO confirm.
target_shape.unsqueeze_(0)
# Label file: space-delimited, with no header row. The first line is skipped;
# it holds the row count, matching the header written back below.
orig_label = pd.read_csv(orig_label_path,
                         delimiter=" ",
                         skiprows=1,
                         header=None)
# Column 5 holds the label names (not used further in the visible code).
orig_label_name = orig_label.iloc[:, 5]
# Columns 6-8 hold the picked 3D coordinates -> (1, P, 3) float32 tensor.
source_points = torch.from_numpy(orig_label.iloc[:, 6:9].to_numpy().astype(
    np.float32))
source_points = source_points.unsqueeze(0)
# find the closest point on the original meshes
source_mesh = om.read_polymesh(source_model)
# source_mesh = om.read_trimesh(source_model)
source_shape_arr = source_mesh.points()
source_shape = source_shape_arr.copy()
# (1, V, 3) float tensor of the source mesh vertices (first 3 columns).
source_shape = torch.from_numpy(source_shape[None, :, :3]).float()
# idx: index of the nearest source vertex for every picked point, (1, P, 1).
_, idx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)

# Take the target vertex that has the same index as the nearest source vertex.
# NOTE(review): this assumes the source and target meshes share vertex
# ordering (registered meshes) -- confirm.
target_points = torch.gather(
    target_shape.unsqueeze(1).expand(-1, source_points.shape[1], -1, -1), 2,
    idx.unsqueeze(-1).expand(-1, -1, -1, 3))
# save to pd again
# Store the vertex index as column 9 and rewrite the original label file. The
# header line is the row count followed by empty fields.
# NOTE(review): idx is assigned as a torch tensor here; elsewhere the same
# column is filled with `.cpu().numpy()` -- consider converting explicitly.
orig_label[9] = idx.squeeze(0).squeeze(-1)
ncol = orig_label.shape[1]
orig_label.to_csv(orig_label_path,
                  sep=" ",
                  header=[str(orig_label.shape[0])] + [""] * (ncol - 1),
                  index=False)
# Only after the write above: replace the coordinates with the target
# positions, so the original file keeps the source coordinates.
# NOTE(review): `.squeeze()` also drops the point axis when P == 1.
orig_label.iloc[:, 6:9] = target_points.squeeze().numpy()
orig_label.to_csv(new_lable,
                  sep=" ",
                  header=[str(orig_label.shape[0])] + [""] * (ncol - 1),
コード例 #5
0
ファイル: optimize_cage.py プロジェクト: star-cold/deep_cage
def optimize(opt):
    """Fit the cage to a new target shape by matching mean value coordinates.

    The mean value coordinates (MVC) of the target correspondences with
    respect to the optimized cage are pushed towards those of the source
    correspondences with respect to the initial cage. The weights therefore
    stay the same as for the original source mesh (target = net(old_source)).

    Correspondences come from, in order of preference:
      1. ".picked" label files next to both ``opt.model`` and
         ``opt.source_model``, with labels matched by name (column 5);
      2. a ".picked" file next to ``opt.source_model`` only, whose vertex
         indices are reused on the target ("Faust" assumption: same vertex
         ordering);
      3. a random vertex subset, when ``opt.corres_idx`` is None and the target
         and the template have the same vertex count.

    Side effects: label files without a vertex-index column (column 9) are
    rewritten with it. The normalized target and the initial
    template/cage/correspondences are written to ``opt.log_dir/opt.subdir``.
    ``opt.model`` is replaced by the path of the normalized target mesh.

    Args:
        opt: options namespace; reads model, source_model, is_poly, ckpt,
            corres_idx, log_dir, subdir, lr, nepochs, dim, clap_weight and
            mvc_weight.
    Returns:
        (cage_v, cage_f): the optimized cage vertices (requires grad) and the
        cage faces.
    Raises:
        ValueError: if no correspondence source applies, or no label names
            match between the two label files.
    """
    def _picked_path(mesh_path):
        # label file next to the mesh: same stem, ".picked" extension
        return os.path.splitext(mesh_path)[0] + ".picked"

    def _write_picked(label, path):
        # the first line holds the row count, padded to the column count
        label.to_csv(path, sep=" ",
                     header=[str(label.shape[0])] + [""] * (label.shape[1] - 1),
                     index=False)

    # load new target
    if opt.is_poly:
        target_mesh = om.read_polymesh(opt.model)
    else:
        target_mesh = om.read_trimesh(opt.model)
    # the mesh's point array is overwritten with the normalized coordinates
    # before the mesh is saved below (presumably a view into the mesh)
    target_shape_arr = target_mesh.points()
    target_shape = torch.from_numpy(
        target_shape_arr.copy()[:, :3].astype(np.float32)).cuda()
    target_shape.unsqueeze_(0)

    states = torch.load(opt.ckpt)
    if "states" in states:
        states = states["states"]
    cage_v = states["template_vertices"].transpose(1, 2).cuda()
    cage_f = states["template_faces"].cuda()
    shape_v = states["source_vertices"].transpose(1, 2).cuda()
    shape_f = states["source_faces"].cuda()

    new_label_path = _picked_path(opt.model)
    orig_label_path = _picked_path(opt.source_model)
    has_new_label = os.path.isfile(new_label_path)
    has_orig_label = os.path.isfile(orig_label_path)

    if has_new_label and has_orig_label:
        logger.info("Loading picked labels {} and {}".format(orig_label_path, new_label_path))
        import pandas as pd
        new_label = pd.read_csv(new_label_path, delimiter=" ", skiprows=1, header=None)
        orig_label = pd.read_csv(orig_label_path, delimiter=" ", skiprows=1, header=None)
        orig_label_name = orig_label.iloc[:, 5]
        # (new row, orig row) pairs for names that occur exactly once in orig
        new_to_orig_idx = []
        for i, name in enumerate(new_label.iloc[:, 5].tolist()):
            matched_idx = orig_label_name[orig_label_name == name].index
            if matched_idx.size == 1:
                new_to_orig_idx.append((i, matched_idx[0]))
        if not new_to_orig_idx:
            raise ValueError("No matching label names between {} and {}".format(
                orig_label_path, new_label_path))
        new_to_orig_idx = np.array(new_to_orig_idx)

        # target correspondences: the stored vertex index (column 9), or snap
        # the picked coordinates (columns 6-8) to the nearest target vertex
        if new_label.shape[1] == 10:
            new_vidx = new_label.iloc[:, 9].to_numpy()[new_to_orig_idx[:, 0]]
            target_points = target_shape[:, new_vidx, :]
        else:
            new_label_points = torch.from_numpy(new_label.iloc[:, 6:9].to_numpy().astype(np.float32))
            target_points = new_label_points.unsqueeze(0).cuda()
            target_points, new_vidx, _ = faiss_knn(1, target_points, target_shape, NCHW=False)
            target_points = target_points.squeeze(2)  # B,N,3
            new_label[9] = new_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            _write_picked(new_label, new_label_path)
            target_points = target_points[:, new_to_orig_idx[:, 0], :]
        target_points = target_points.cuda()

        # source correspondences, same scheme on the original mesh
        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :, :3]).float()
        if orig_label.shape[1] == 10:
            orig_vidx = orig_label.iloc[:, 9].to_numpy()[new_to_orig_idx[:, 1]]
            source_points = source_shape[:, orig_vidx, :]
        else:
            orig_label_points = torch.from_numpy(orig_label.iloc[:, 6:9].to_numpy().astype(np.float32))
            source_points = orig_label_points.unsqueeze(0)
            # find the closest point on the original meshes
            source_points, orig_vidx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2)  # B,N,3
            orig_label[9] = orig_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            _write_picked(orig_label, orig_label_path)
            source_points = source_points[:, new_to_orig_idx[:, 1], :]

        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
        source_points = source_points.cuda()
    elif has_orig_label:
        logger.info("Assuming Faust model")
        logger.info("Loading picked labels {}".format(orig_label_path))
        import pandas as pd
        orig_label = pd.read_csv(orig_label_path, delimiter=" ", skiprows=1, header=None)
        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :, :3]).cuda().float()
        if orig_label.shape[1] == 10:
            idx = torch.from_numpy(orig_label.iloc[:, 9].to_numpy()).long()
            source_points = source_shape[:, idx, :]
            target_points = target_shape[:, idx, :]
        else:
            source_points = torch.from_numpy(orig_label.iloc[:, 6:9].to_numpy().astype(np.float32))
            source_points = source_points.unsqueeze(0).cuda()
            # find the closest point on the original meshes
            source_points, idx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2)  # B,N,3
            # (1, P, 1) -> (P,) so that target_points is (1, P, 3) like source_points
            idx = idx.squeeze(0).squeeze(-1)
            target_points = target_shape[:, idx, :]

        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
    elif opt.corres_idx is None and target_shape.shape[1] == shape_v.shape[1]:
        logger.info("No correspondence provided, assuming registered Faust models")
        corresp_v = torch.unique(torch.randint(0, shape_v.shape[1], (4800,))).cuda()
        target_points = torch.index_select(target_shape, 1, corresp_v)
        source_points = torch.index_select(shape_v, 1, corresp_v)
    else:
        raise ValueError(
            "No correspondences between {} and {}: expected .picked label files "
            "or registered meshes with the same vertex count".format(
                opt.source_model, opt.model))

    # centre the target and rescale it by the ratio of bounding-box scales
    # (component 1 of source_scale / target_scale)
    target_shape[0], target_center, target_scale = center_bounding_box(target_shape[0])
    _, _, source_scale = center_bounding_box(shape_v[0])
    target_scale_factor = (source_scale/target_scale)[1]
    target_shape *= target_scale_factor
    target_points -= target_center
    target_points = (target_points*target_scale_factor).detach()

    # make sure test use the normalized target: write it into the mesh and save
    out_dir = os.path.join(opt.log_dir, opt.subdir)
    target_shape_arr[:] = target_shape[0].cpu().numpy()
    normalized_path = os.path.join(
        out_dir, os.path.splitext(os.path.basename(opt.model))[0] + "_normalized.obj")
    om.write_mesh(normalized_path, target_mesh)
    opt.model = normalized_path
    pymesh.save_mesh_raw(os.path.join(out_dir, "template-initial.obj"),
                         shape_v[0].cpu().numpy(), shape_f[0].cpu().numpy())
    pymesh.save_mesh_raw(os.path.join(out_dir, "cage-initial.obj"),
                         cage_v[0].cpu().numpy(), cage_f[0].cpu().numpy())
    save_ply(target_points[0].cpu().numpy(), os.path.join(out_dir, "target_points.ply"))
    save_ply(source_points[0].cpu().numpy(), os.path.join(out_dir, "source_points.ply"))
    logger.info("Optimizing for {} corresponding vertices".format(
        target_points.shape[1]))

    cage_init = cage_v.clone().detach()
    lap_loss = MeshLaplacianLoss(torch.nn.MSELoss(reduction="none"), use_cot=True,
                                 use_norm=True, consistent_topology=True, precompute_L=True)
    mvc_reg_loss = MVCRegularizer(threshold=50, beta=1.0, alpha=0.0)
    cage_v.requires_grad_(True)
    optimizer = torch.optim.Adam([cage_v], lr=opt.lr, betas=(0.5, 0.9))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(opt.nepochs*0.4), gamma=0.5, last_epoch=-1)

    if opt.dim == 3:
        # reference weights: source points w.r.t. the initial cage
        weights_ref = mean_value_coordinates_3D(
            source_points, cage_init, cage_f, verbose=False)
    else:
        raise NotImplementedError

    for t in range(opt.nepochs):
        optimizer.zero_grad()
        weights = mean_value_coordinates_3D(
            target_points, cage_v, cage_f, verbose=False)
        loss_mvc = torch.mean((weights-weights_ref)**2)
        # a tensor (not the int 0) so that reg.item() works with both weights off
        reg = cage_v.new_zeros(())
        if opt.clap_weight > 0:
            reg = lap_loss(cage_init, cage_v, face=cage_f)*opt.clap_weight
            reg = reg.mean()
        if opt.mvc_weight > 0:
            reg = reg + mvc_reg_loss(weights)*opt.mvc_weight

        loss = loss_mvc + reg
        if (t+1) % 50 == 0:
            print("t {}/{} mvc_loss: {} reg: {}".format(t,
                                                        opt.nepochs, loss_mvc.item(), reg.item()))

        # early stop once the weights match closely enough
        if loss_mvc.item() < 5e-6:
            break
        loss.backward()
        optimizer.step()
        scheduler.step()

    return cage_v, cage_f
コード例 #6
0
ファイル: camera.py プロジェクト: zuru/DSS
    def __init__(self,
                 nCam,
                 offset,
                 focalLength,
                 device=None,
                 points=None,
                 normals=None,
                 camWidth=256,
                 camHeight=256,
                 filename="../example_data/pointclouds/sphere_300.ply",
                 closer=True):
        """
        Create camera positions around a shape and a look-at pose for each one.

        Camera directions are read from the point cloud ``filename`` via
        ``read_ply(filename, nCam)`` (xyz columns only). When ``filename`` is
        None, they are ``nCam`` furthest-point samples of ``points`` instead.
        The directions are multiplied by ``offset``, which is enlarged by the
        bounding-box extent of ``points`` when given, and then shifted to a
        center. Every camera looks at that center with +z as the up vector.

        input:
            nCam:           total number of cameras
            offset:         a number; distance to the shape surface
            focalLength:    a number
            device:         target device; defaults to points.device, else the
                            current CUDA device
            (optional) points (B,N,3or4); a 2-D (N,·) tensor gets a batch dim
            (optional) normals: per-point normals of ``points``
            camWidth, camHeight: stored as-is
            filename:       point cloud of camera directions (None: sample ``points``)
            closer:         stored as self.closer; presumably selects a decreasing
                            distance elsewhere (not read in the lines shown here)
        sets:
            allPositions (B,C,3)
            rotation, position: results of batchLookAt (rotation presumably (B,C,3,3))
        """
        # Resolve the device: explicit argument, else that of ``points``, else
        # the current CUDA device.
        if device is None:
            if points is not None:
                self.device = points.device
            else:
                self.device = torch.cuda.current_device()
        else:
            self.device = device
        self.closer = closer
        if filename is not None:
            # Camera directions from the point-cloud file (xyz columns only),
            # with a leading batch dimension of 1.
            self.allPositions = torch.from_numpy(read_ply(
                filename, nCam)).to(device=self.device)[:, :3]
            self.allPositions = self.allPositions.unsqueeze(0)
        else:
            # nCam furthest-point samples of ``points`` (computed on CUDA).
            # NOTE(review): this branch requires ``points``; filename=None with
            # points=None raises here.
            sampleIdx, self.allPositions = operations.furthest_point_sample(
                points.cuda(), nCam, NCHW=False)
            self.allPositions = self.allPositions.to(self.device)
            if normals is not None:
                # Replace ``normals`` by the mean normal of the 100 nearest
                # points around each sampled position (one normal per camera).
                _, idx, _ = operations.faiss_knn(100,
                                                 self.allPositions.cpu(),
                                                 points.cpu(),
                                                 NCHW=False)
                knn_normals = torch.gather(
                    normals.unsqueeze(1).expand(-1, self.allPositions.shape[1],
                                                -1, -1), 2,
                    idx.unsqueeze(-1).expand(-1, -1, -1, normals.shape[-1]))
                normals = torch.mean(knn_normals, dim=2).to(self.device)

        if points is not None:
            if points.dim() == 2:
                points = points.unsqueeze(0)
            # Enlarge the offset by the per-axis bounding-box extent (B,1,D).
            # NOTE(review): with 4-column points this extent has 4 entries and
            # will not broadcast against the 3-column positions -- confirm.
            maxP = torch.max(points, dim=1, keepdim=True)[0]
            minP = torch.min(points, dim=1, keepdim=True)[0]
            bb = maxP - minP
            offset = offset + bb
            if normals is not None:
                # Cameras look at the sampled positions and are placed along
                # the normals shifted by the mean normal, with a little jitter.
                # NOTE(review): when ``filename`` is given, ``normals`` was not
                # averaged above and keeps one row per input point -- confirm.
                center = self.allPositions
                # self.allPositions = (torch.mean(normals, dim=1, keepdim=True))
                self.allPositions = normals + (torch.mean(
                    normals, dim=1, keepdim=True))
                self.allPositions += torch.randn_like(self.allPositions) * 0.01
            else:
                # Look at the centroid of the points.
                center = torch.mean(points, dim=1, keepdim=True)
        else:
            # No shape given: the cameras look at the origin.
            center = torch.zeros([1, 1, 3],
                                 dtype=self.allPositions.dtype,
                                 device=self.allPositions.device)

        # Scale the directions by the offset and translate them to the center.
        self.allPositions = self.allPositions * offset
        self.allPositions = center + self.allPositions
        # Look-at target per camera: center (Bx1x3, or one point per camera in
        # the normals case) broadcast to the shape of allPositions.
        self.to = center.expand_as(self.allPositions)
        # Up vectors, one per camera (same shape as allPositions).
        # self.ups = torch.tensor([0, 1, 0], dtype=self.to.dtype, device=self.to.device).view(1, 1, 3).expand_as(self.allPositions)
        # for sketchfab: +z is up
        self.ups = torch.tensor([0, 0, 1],
                                dtype=self.to.dtype,
                                device=self.to.device).view(1, 1, 3).expand_as(
                                    self.allPositions)
        # Tiny noise on the up vectors, presumably to avoid a degenerate
        # look-at when a view direction is parallel to +z.
        self.ups = self.ups + torch.randn_like(self.ups) * 0.0001
        self.rotation, self.position = batchLookAt(self.allPositions, self.to,
                                                   self.ups)
        # Presumably the index of the current camera; starts at 0.
        self.idx = 0
        # Number of cameras: dimension 1 of the rotation tensor.
        self.length = self.rotation.shape[1]
        self.focalLength = focalLength
        self.camWidth = camWidth
        self.camHeight = camHeight