Code example #1
    def forward(self, data, sub_x, sub_y):
        '''
        data: [N, H, W, C] or [1, H, W, C]
        sub_x: [N, ...]
        sub_y: [N, ...]
        return: [N, ..., C]
        '''
        if data.shape[0] == 1:
            # a single shared map: sample it with the whole batch of coordinates
            return misc.interpolate_bilinear(data[0, :], sub_x, sub_y)  # [N, ..., C]
        elif data.shape[0] == sub_x.shape[0]:
            # per-sample maps: sample each map with its own coordinates
            out = []
            for i in range(data.shape[0]):
                out.append(
                    misc.interpolate_bilinear(data[i, :], sub_x[i, :],
                                              sub_y[i, :]))  # [..., C]
            return torch.stack(out)  # [N, ..., C]
        else:
            raise ValueError('data.shape[0] should be 1 or batch size')
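For reference, here is a minimal sketch of a bilinear sampler matching the calling convention documented above (data is [H, W, C]; sub_x indexes width, sub_y indexes height; the result is [..., C]). This is an illustrative assumption, not the project's misc.interpolate_bilinear, which may handle borders and out-of-range coordinates differently.

import torch

def interpolate_bilinear_sketch(data, sub_x, sub_y):
    # data: [H, W, C]; sub_x, sub_y: fractional pixel coordinates, any matching shape
    h, w = data.shape[0], data.shape[1]
    x0 = sub_x.floor().long().clamp(0, w - 2)        # left neighbor column
    y0 = sub_y.floor().long().clamp(0, h - 2)        # top neighbor row
    wx = (sub_x - x0.to(sub_x.dtype)).unsqueeze(-1)  # [..., 1] horizontal weight
    wy = (sub_y - y0.to(sub_y.dtype)).unsqueeze(-1)  # [..., 1] vertical weight
    top = data[y0, x0] * (1 - wx) + data[y0, x0 + 1] * wx
    bottom = data[y0 + 1, x0] * (1 - wx) + data[y0 + 1, x0 + 1] * wx
    return top * (1 - wy) + bottom * wy              # [..., C]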
Code example #2
File: render.py Project: rodrygojose/relightable-nr
def lp_mapping(lp, dir_map, alpha_map):
    '''
    lp: torch.FloatTensor, [H, W, C]
    dir_map: torch.FloatTensor, [3, ...]
    alpha_map: torch.FloatTensor, [1, ...] or [3, ...]
    return: torch.FloatTensor, [..., C]
    '''
    uv_map = render.spherical_mapping(dir_map)  # [2, ...]
    # mask out unused regions: uv drops to -1 wherever alpha_map is zero
    uv_map = uv_map * alpha_map - (alpha_map == 0).to(dir_map.dtype)  # [2, ...]
    sample_img = misc.interpolate_bilinear(
        lp,
        uv_map[0, :] * float(lp.shape[1] - 1),
        uv_map[1, :] * float(lp.shape[0] - 1))
    return sample_img
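render.spherical_mapping is not shown on this page; it converts unit direction vectors into light probe uv coordinates in [0, 1]. A common equirectangular convention looks like the sketch below; the axis conventions here are an assumption and may differ from the project's.

import math
import torch

def spherical_mapping_sketch(dir_map):
    # dir_map: [3, ...] unit direction vectors -> uv in [0, 1], shape [2, ...]
    u = torch.atan2(dir_map[0], dir_map[2]) / (2.0 * math.pi) + 0.5  # azimuth
    v = torch.acos(dir_map[1].clamp(-1.0, 1.0)) / math.pi            # polar angle
    return torch.stack((u, v))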
Code example #3
    def forward(self, uv_map, sh_basis_map=None, sh_start_ch=3):
        '''
        uv_map: [N, H, W, 2]
        sh_basis_map: [N, H, W, 9]
        return: [N, C, H, W]
        '''
        for ithLevel in range(self.mipmap_level):
            texture_size_i = self.textures_size[ithLevel]
            texture_i = self.textures[ithLevel]

            # vertex texcoords map in units of texels, with v flipped so the
            # uv origin (bottom-left) matches the texel origin (top-left)
            uv_map_unit_texel = uv_map * (texture_size_i - 1)
            uv_map_unit_texel[..., -1] = texture_size_i - 1 - uv_map_unit_texel[..., -1]

            # sample from texture (bilinear) and accumulate over mipmap levels
            sampled_i = misc.interpolate_bilinear(
                texture_i[0, :],
                uv_map_unit_texel[..., 0],
                uv_map_unit_texel[..., 1]).permute((0, 3, 1, 2))  # [N, C, H, W]
            if ithLevel == 0:
                output = sampled_i
            else:
                output = output + sampled_i

        # apply spherical harmonics to the designated channel range
        if self.apply_sh and sh_basis_map is not None:
            output[:, sh_start_ch:sh_start_ch + 9, :, :] = \
                output[:, sh_start_ch:sh_start_ch + 9, :, :] * \
                sh_basis_map.permute((0, 3, 1, 2))

        return output
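The uv-to-texel conversion above flips the last coordinate because uv space conventionally puts the origin at the bottom-left while texel rows are indexed from the top. A quick toy check of that step:

import torch

texture_size = 4
uv = torch.tensor([[0.0, 0.0], [1.0, 1.0]])         # two uv points in [0, 1]
texel = uv * (texture_size - 1)                     # scale to texel units
texel[..., -1] = texture_size - 1 - texel[..., -1]  # flip v
print(texel)  # tensor([[0., 3.], [3., 0.]])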
Code example #4
File: render.py Project: rodrygojose/relightable-nr
def texture_mapping(texture, uv_map):
    '''
    texture: torch.FloatTensor, [H, W, C]
    uv_map: torch.FloatTensor, [N, H, W, 2]
    return: torch.FloatTensor, [N, H, W, C]
    '''
    tex_h = texture.shape[0] * 1.0
    tex_w = texture.shape[1] * 1.0

    # clone before the in-place ops below so the caller's uv_map is not mutated
    uv_map_unit_texel = uv_map.clone()
    uv_map_unit_texel[..., 0] = uv_map_unit_texel[..., 0] * (tex_w - 1)
    uv_map_unit_texel[..., 1] = uv_map_unit_texel[..., 1] * (tex_h - 1)
    uv_map_unit_texel[..., 1] = tex_h - 1 - uv_map_unit_texel[..., 1]  # flip v
    img = misc.interpolate_bilinear(texture, uv_map_unit_texel[..., 0],
                                    uv_map_unit_texel[..., 1])

    return img
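A hypothetical call, with random tensors shaped as in the docstring (assumes misc is importable as in the module above):

import torch

texture = torch.rand(256, 256, 3)       # [H, W, C]
uv_map = torch.rand(4, 128, 128, 2)     # [N, H, W, 2], values in [0, 1]
img = texture_mapping(texture, uv_map)  # [N, H, W, C] -> [4, 128, 128, 3]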
Code example #5
    def __init__(self,
                 l_dir,
                 num_lighting=1,
                 num_channel=3,
                 lp_dataloader=None,
                 fix_params=False,
                 lp_img_h=1600,
                 lp_img_w=3200):
        '''
        l_dir: torch.FloatTensor, [3, num_sample], sampled light directions
        num_lighting: int, number of lighting conditions
        num_channel: int, number of color channels
        lp_dataloader: dataloader for light probes (if not None, num_lighting is ignored)
        fix_params: bool, whether to freeze the light sample parameters
        lp_img_h, lp_img_w: int, height and width to which light probe images are resized
        '''
        super().__init__()

        self.register_buffer('l_dir', l_dir)  # [3, num_sample]
        self.num_sample = l_dir.shape[1]
        self.num_lighting = num_lighting
        self.num_channel = num_channel
        self.fix_params = fix_params
        self.lp_img_h = lp_img_h
        self.lp_img_w = lp_img_w

        if lp_dataloader is not None:
            self.num_lighting = len(lp_dataloader)

        # spherical mapping to get light probe uv
        l_samples_uv = render.spherical_mapping(l_dir)
        self.register_buffer('l_samples_uv', l_samples_uv)  # [2, num_sample]

        # light samples as learnable parameters
        self.l_samples = nn.Parameter(
            torch.zeros((self.num_lighting, self.num_sample, self.num_channel),
                        dtype=torch.float32)
        )  # [num_lighting, num_sample, num_channel]

        # initialize light samples from light probes
        if lp_dataloader is not None:
            lps = []
            for lp_idx, lp in enumerate(lp_dataloader):
                # resize each light probe to the configured resolution
                lp_img = lp['lp_img'][0, :].permute((1, 2, 0))
                lps.append(
                    torch.from_numpy(
                        cv2.resize(lp_img.cpu().detach().numpy(),
                                   (lp_img_w, lp_img_h),
                                   interpolation=cv2.INTER_AREA)))  # [H, W, C]
                lp_img = lps[-1]
                # sample the probe at the precomputed uv locations
                self.l_samples.data[lp_idx, :] = misc.interpolate_bilinear(
                    lp_img.to(self.l_samples_uv.device),
                    (self.l_samples_uv[None, 0, :] *
                     float(lp_img.shape[1])).clamp(max=lp_img.shape[1] - 1),
                    (self.l_samples_uv[None, 1, :] *
                     float(lp_img.shape[0])).clamp(max=lp_img.shape[0] - 1))[0, :]

            lps = torch.stack(lps)
            self.register_buffer('lps', lps)  # [num_lighting, H, W, C]

        # change to non-learnable
        if self.fix_params:
            self.l_samples.requires_grad_(False)
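A hypothetical construction of this module; the class name LightingLP and the randomly sampled unit directions below are assumptions for illustration:

import torch
import torch.nn.functional as F

num_sample = 512
l_dir = F.normalize(torch.randn(3, num_sample), dim=0)       # [3, num_sample]
lighting = LightingLP(l_dir, num_lighting=2, num_channel=3)  # class name assumed
print(lighting.l_samples.shape)  # torch.Size([2, 512, 3])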
Code example #6
    def forward(self, proj, pose, dist_coeffs, offset, scale):
        _, depth, alpha, face_index_map, weight_map, v_uvz, faces_v_uvz, faces_v_idx = self.renderer(
            self.vertices,
            self.faces,
            torch.tanh(self.textures),
            K=proj,
            R=pose[:, :3, :3],
            t=pose[:, :3, -1, None].permute(0, 2, 1),
            dist_coeffs=dist_coeffs,
            offset=offset,
            scale=scale)
        batch_size = face_index_map.shape[0]
        image_size = face_index_map.shape[1]

        # find vertices on the frontal surface: project each vertex into the
        # image, sample the depth buffer there, and compare with its own depth
        v_uvz[..., 0] = (v_uvz[..., 0] * 0.5 + 0.5) * depth.shape[2]  # [1, num_vertex]
        v_uvz[..., 1] = (1 - (v_uvz[..., 1] * 0.5 + 0.5)) * depth.shape[1]  # [1, num_vertex]
        v_depth = misc.interpolate_bilinear(depth[0, :, :, None],
                                            v_uvz[..., 0],
                                            v_uvz[..., 1])  # [1, num_vertex, 1]
        v_front_mask = ((v_uvz[0, :, 2] - v_depth[0, :, 0]) <
                        self.mesh_span * 5e-3)[None, :]  # [1, num_vertex]

        # perspective-correct weights: scale the screen-space barycentric
        # weights by 1/z per vertex, then renormalize with the rendered depth
        faces_v_z_inv_map = torch.cuda.FloatTensor(batch_size, image_size,
                                                   image_size, 3).fill_(0.0)
        for i in range(batch_size):
            faces_v_z_inv_map[i, ...] = \
                1 / faces_v_uvz[i, face_index_map[i, ...].long()][..., -1]
        # note: unsqueeze_ modifies depth in place, so the depth returned below
        # has shape [batch_size, image_size, image_size, 1]
        weight_map = (faces_v_z_inv_map * weight_map) * depth.unsqueeze_(-1)  # [batch_size, image_size, image_size, 3]
        weight_map = weight_map.unsqueeze_(-1)  # [batch_size, image_size, image_size, 3, 1]

        # uv map
        if self.renderer.fill_back:
            # append faces with reversed vertex order for back-face filling
            faces_vt_idx = torch.cat(
                (self.faces_vt_idx,
                 self.faces_vt_idx[:, :, list(reversed(range(self.faces_vt_idx.shape[-1])))]),
                dim=1).detach()
        else:
            faces_vt_idx = self.faces_vt_idx.detach()
        faces_vt = nr.vertex_attrs_to_faces(self.vertices_texcoords,
                                            faces_vt_idx)  # [1, num_face, 3, 2]
        uv_map = faces_vt[:, face_index_map.long()].squeeze_(0)  # [batch_size, image_size, image_size, 3, 2], before weighted combination
        uv_map = (uv_map * weight_map).sum(-2)  # [batch_size, image_size, image_size, 2], after weighted combination
        uv_map = uv_map - uv_map.floor()  # handle uv wrapping, keep uv in [0, 1]

        # normal map in world space
        if self.renderer.fill_back:
            faces_vn_idx = torch.cat(
                (self.faces_vn_idx,
                 self.faces_vn_idx[:, :, list(reversed(range(self.faces_vn_idx.shape[-1])))]),
                dim=1).detach()
        else:
            faces_vn_idx = self.faces_vn_idx.detach()
        faces_vn = nr.vertex_attrs_to_faces(self.vertices_normals,
                                            faces_vn_idx)  # [1, num_face, 3, 3]
        normal_map = faces_vn[:, face_index_map.long()].squeeze_(0)  # [batch_size, image_size, image_size, 3, 3], before weighted combination
        normal_map = (normal_map * weight_map).sum(-2)  # [batch_size, image_size, image_size, 3], after weighted combination
        normal_map = torch.nn.functional.normalize(normal_map, dim=-1)

        # normal map in camera space
        normal_map_flat = normal_map.flatten(start_dim=1, end_dim=2).permute((0, 2, 1))
        normal_map_cam = pose[:, :3, :3].matmul(normal_map_flat) \
            .permute((0, 2, 1)).reshape(normal_map.shape)
        normal_map_cam = torch.nn.functional.normalize(normal_map_cam, dim=-1)

        # position map in world space
        faces_v = nr.vertex_attrs_to_faces(self.vertices,
                                           faces_v_idx)  # [1, num_face, 3, 3]
        position_map = faces_v[0, face_index_map.long()]  # [batch_size, image_size, image_size, 3, 3], before weighted combination
        position_map = (position_map * weight_map).sum(-2)  # [batch_size, image_size, image_size, 3], after weighted combination

        # position map in camera space
        position_map_flat = position_map.flatten(start_dim=1, end_dim=2).permute((0, 2, 1))
        position_map_cam = pose[:, :3, :3].matmul(position_map_flat) \
            .permute((0, 2, 1)).reshape(position_map.shape) \
            + pose[:, :3, -1][:, None, None, :]

        return (uv_map, alpha, face_index_map, weight_map, faces_v_idx,
                normal_map, normal_map_cam, faces_v, faces_vt, position_map,
                position_map_cam, depth, v_uvz, v_front_mask)
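The weighting in this forward pass follows the standard perspective-correct interpolation identity: divide each screen-space barycentric weight by its vertex depth, then renormalize with the interpolated depth. A toy check with made-up numbers:

import torch

w = torch.tensor([0.2, 0.3, 0.5])  # screen-space barycentric weights
z = torch.tensor([1.0, 2.0, 4.0])  # per-vertex camera-space depths
z_pixel = 1.0 / (w / z).sum()      # perspective-correct interpolated depth
w_persp = (w / z) * z_pixel        # corrected per-vertex weights
print(w_persp.sum())               # tensor(1.) up to float rounding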