def forward(self, data, sub_x, sub_y):
    '''
    data: [N, H, W, C] or [1, H, W, C]
    sub_x: [N, ...]
    sub_y: [N, ...]
    return: [N, ..., C]
    '''
    if data.shape[0] == 1:
        # a single map shared across the batch: sample it with all coordinates at once
        return misc.interpolate_bilinear(data[0, :], sub_x, sub_y)  # [N, ..., C]
    elif data.shape[0] == sub_x.shape[0]:
        # one map per batch item: sample each with its own coordinates
        out = []
        for i in range(data.shape[0]):
            out.append(
                misc.interpolate_bilinear(data[i, :], sub_x[i, :],
                                          sub_y[i, :]))  # [..., C]
        return torch.stack(out)  # [N, ..., C]
    else:
        raise ValueError('data.shape[0] should be 1 or batch size')
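# NOTE (illustrative sketch, not part of the original source): every function in
# this file leans on `misc.interpolate_bilinear`. Below is a minimal,
# self-contained sketch of the behavior it is assumed to have -- bilinear
# sampling of an [H, W, C] image at fractional pixel coordinates -- purely for
# reference; the real helper may differ in border handling.
def _interpolate_bilinear_sketch(img, x, y):
    # img: [H, W, C]; x, y: [...] fractional pixel coordinates, assumed in range
    H, W = img.shape[0], img.shape[1]
    x0 = x.floor().long().clamp(0, W - 1)
    y0 = y.floor().long().clamp(0, H - 1)
    x1 = (x0 + 1).clamp(max=W - 1)
    y1 = (y0 + 1).clamp(max=H - 1)
    wx = (x - x0.to(x.dtype)).unsqueeze(-1)  # [..., 1] fractional weights
    wy = (y - y0.to(y.dtype)).unsqueeze(-1)
    top = img[y0, x0, :] * (1 - wx) + img[y0, x1, :] * wx
    bot = img[y1, x0, :] * (1 - wx) + img[y1, x1, :] * wx
    return top * (1 - wy) + bot * wy  # [..., C]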
def lp_mapping(lp, dir_map, alpha_map):
    '''
    lp: torch.FloatTensor, [H, W, C]
    dir_map: torch.FloatTensor, [3, ...]
    alpha_map: torch.FloatTensor, [1, ...] or [3, ...]
    return: torch.FloatTensor, [..., C]
    '''
    uv_map = render.spherical_mapping(dir_map)  # [2, ...]
    uv_map = uv_map * alpha_map - (alpha_map == 0).to(
        dir_map.dtype)  # [2, ...], mask out unused regions
    sample_img = misc.interpolate_bilinear(
        lp, uv_map[0, :] * float(lp.shape[1] - 1),
        uv_map[1, :] * float(lp.shape[0] - 1))
    return sample_img
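# Hedged sketch (an assumption, not the repo's code) of the direction-to-uv
# mapping `render.spherical_mapping` is expected to perform: a lat-long /
# equirectangular projection of unit directions onto uv in [0, 1]^2. The actual
# axis and wrapping conventions in `render` may differ.
import math

def _spherical_mapping_sketch(dirs):
    # dirs: [3, ...] unit direction vectors
    u = 0.5 + torch.atan2(dirs[0, :], dirs[2, :]) / (2.0 * math.pi)  # azimuth -> u
    v = torch.acos(dirs[1, :].clamp(-1.0, 1.0)) / math.pi            # polar -> v
    return torch.stack((u, v))  # [2, ...], uv in [0, 1]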
def forward(self, uv_map, sh_basis_map=None, sh_start_ch=3):
    '''
    uv_map: [N, H, W, 2]
    sh_basis_map: [N, H, W, 9]
    return: [N, C, H, W]
    '''
    for ithLevel in range(self.mipmap_level):
        texture_size_i = self.textures_size[ithLevel]
        texture_i = self.textures[ithLevel]

        # vertex texcoords map in unit of texel
        uv_map_unit_texel = uv_map * (texture_size_i - 1)
        uv_map_unit_texel[..., -1] = texture_size_i - 1 - uv_map_unit_texel[..., -1]

        # sample from texture (bilinear), accumulating across mipmap levels
        sample = misc.interpolate_bilinear(
            texture_i[0, :], uv_map_unit_texel[..., 0],
            uv_map_unit_texel[..., 1]).permute((0, 3, 1, 2))  # [N, C, H, W]
        output = sample if ithLevel == 0 else output + sample

    # apply spherical harmonics to the designated block of channels
    if self.apply_sh and sh_basis_map is not None:
        output[:, sh_start_ch:sh_start_ch + 9, :, :] = \
            output[:, sh_start_ch:sh_start_ch + 9, :, :] * \
            sh_basis_map.permute((0, 3, 1, 2))

    return output
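# Illustrative alternative (a sketch under assumptions, not the module's code):
# the same per-level "sample and sum" can be expressed with
# torch.nn.functional.grid_sample; align_corners=True matches the
# `uv * (size - 1)` texel convention used above.
import torch.nn.functional as F

def _mipmap_sample_sketch(textures, uv_map):
    # textures: list of [1, C, H_i, W_i] tensors; uv_map: [N, H, W, 2] in [0, 1]
    grid = uv_map * 2.0 - 1.0  # grid_sample expects coordinates in [-1, 1]
    grid = torch.stack((grid[..., 0], -grid[..., 1]),
                       dim=-1)  # flip v, mirroring the texel-space flip above
    output = 0.0
    for tex in textures:
        output = output + F.grid_sample(
            tex.expand(uv_map.shape[0], -1, -1, -1), grid,
            mode='bilinear', align_corners=True)
    return output  # [N, C, H, W]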
def texture_mapping(texture, uv_map):
    '''
    texture: torch.FloatTensor, [H, W, C]
    uv_map: torch.FloatTensor, [N, H, W, 2]
    return: torch.FloatTensor, [N, H, W, C]
    '''
    tex_h = texture.shape[0] * 1.0
    tex_w = texture.shape[1] * 1.0
    # clone so the in-place writes below do not leak into the caller's uv_map
    uv_map_unit_texel = uv_map.clone()
    uv_map_unit_texel[..., 0] = uv_map_unit_texel[..., 0] * (tex_w - 1)
    uv_map_unit_texel[..., 1] = uv_map_unit_texel[..., 1] * (tex_h - 1)
    uv_map_unit_texel[..., 1] = tex_h - 1 - uv_map_unit_texel[..., 1]
    img = misc.interpolate_bilinear(texture, uv_map_unit_texel[..., 0],
                                    uv_map_unit_texel[..., 1])
    return img
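# Why the clone() above matters (self-contained demonstration of plain PyTorch
# semantics, independent of this repo): bare assignment aliases a tensor, so
# indexed in-place writes would otherwise propagate back to the caller.
def _clone_aliasing_demo():
    uv = torch.rand(4, 4, 2)
    alias = uv                        # no copy: both names share storage
    alias[..., 1] = 1.0 - alias[..., 1]
    assert torch.equal(uv, alias)     # the caller-side tensor changed too
    safe = uv.clone()                 # independent storage
    safe[..., 1] = 0.0
    assert not torch.equal(uv, safe)  # original left untouched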
def __init__(self,
             l_dir,
             num_lighting=1,
             num_channel=3,
             lp_dataloader=None,
             fix_params=False,
             lp_img_h=1600,
             lp_img_w=3200):
    '''
    l_dir: torch.FloatTensor, [3, num_sample], sampled light directions
    num_lighting: int, number of lightings
    num_channel: int, number of color channels
    lp_dataloader: dataloader for light probes (if not None, num_lighting is ignored)
    fix_params: bool, whether to freeze the parameters
    lp_img_h, lp_img_w: int, resolution that light probes are resized to
    '''
    super().__init__()

    self.register_buffer('l_dir', l_dir)  # [3, num_sample]
    self.num_sample = l_dir.shape[1]
    self.num_lighting = num_lighting
    self.num_channel = num_channel
    self.fix_params = fix_params
    self.lp_img_h = lp_img_h
    self.lp_img_w = lp_img_w

    if lp_dataloader is not None:
        self.num_lighting = len(lp_dataloader)

    # spherical mapping to get light probe uv
    l_samples_uv = render.spherical_mapping(l_dir)
    self.register_buffer('l_samples_uv', l_samples_uv)  # [2, num_sample]

    # light samples as learnable parameters
    self.l_samples = nn.Parameter(
        torch.zeros((self.num_lighting, self.num_sample, self.num_channel),
                    dtype=torch.float32))  # [num_lighting, num_sample, num_channel]

    # initialize light samples from light probes
    if lp_dataloader is not None:
        lp_idx = 0
        lps = []
        for lp in lp_dataloader:
            # resize each probe to a common resolution
            lp_img = lp['lp_img'][0, :].permute((1, 2, 0))
            lps.append(
                torch.from_numpy(
                    cv2.resize(lp_img.cpu().detach().numpy(),
                               (lp_img_w, lp_img_h),
                               interpolation=cv2.INTER_AREA)))  # [H, W, C]
            lp_img = lps[-1]
            # sample the probe at the precomputed uv of each light direction
            self.l_samples.data[lp_idx, :] = misc.interpolate_bilinear(
                lp_img.to(self.l_samples_uv.device),
                (self.l_samples_uv[None, 0, :] *
                 float(lp_img.shape[1])).clamp(max=lp_img.shape[1] - 1),
                (self.l_samples_uv[None, 1, :] *
                 float(lp_img.shape[0])).clamp(max=lp_img.shape[0] - 1))[0, :]
            lp_idx += 1
        lps = torch.stack(lps)
        self.register_buffer('lps', lps)  # [num_lighting, H, W, C]

    # change to non-learnable
    if self.fix_params:
        self.l_samples.requires_grad_(False)
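# Hedged sketch of building the `l_dir` input this module expects: roughly
# uniform unit directions from a Fibonacci sphere. This is only an illustration;
# the repo may generate its light sample directions differently.
import math

def _fibonacci_sphere_dirs(num_sample):
    i = torch.arange(num_sample, dtype=torch.float32)
    phi = math.pi * (3.0 - math.sqrt(5.0)) * i  # golden-angle azimuth steps
    y = 1.0 - 2.0 * (i + 0.5) / num_sample      # heights in (-1, 1)
    r = torch.sqrt((1.0 - y * y).clamp(min=0.0))
    return torch.stack((r * torch.cos(phi), y, r * torch.sin(phi)))  # [3, num_sample]

# e.g. (hypothetical class name):
# lighting = LightingLP(_fibonacci_sphere_dirs(1024), num_lighting=4)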
def forward(self, proj, pose, dist_coeffs, offset, scale):
    _, depth, alpha, face_index_map, weight_map, v_uvz, faces_v_uvz, faces_v_idx = self.renderer(
        self.vertices,
        self.faces,
        torch.tanh(self.textures),
        K=proj,
        R=pose[:, :3, :3],
        t=pose[:, :3, -1, None].permute(0, 2, 1),
        dist_coeffs=dist_coeffs,
        offset=offset,
        scale=scale)
    batch_size = face_index_map.shape[0]
    image_size = face_index_map.shape[1]

    # find indices of vertices on the frontal face: project vertices to screen
    # space and compare their depth against the rendered depth buffer
    v_uvz[..., 0] = (v_uvz[..., 0] * 0.5 + 0.5) * depth.shape[2]  # [1, num_vertex]
    v_uvz[..., 1] = (1 - (v_uvz[..., 1] * 0.5 + 0.5)) * depth.shape[1]  # [1, num_vertex]
    v_depth = misc.interpolate_bilinear(depth[0, :, :, None], v_uvz[..., 0],
                                        v_uvz[..., 1])  # [1, num_vertex, 1]
    v_front_mask = ((v_uvz[0, :, 2] - v_depth[0, :, 0]) <
                    self.mesh_span * 5e-3)[None, :]  # [1, num_vertex]

    # perspective-correct weights
    faces_v_z_inv_map = torch.zeros((batch_size, image_size, image_size, 3),
                                    dtype=torch.float32, device=depth.device)
    for i in range(batch_size):
        faces_v_z_inv_map[i, ...] = 1 / faces_v_uvz[i, face_index_map[i, ...].long()][..., -1]
    # note: unsqueeze_ modifies depth in place, so the returned depth is [B, H, W, 1]
    weight_map = (faces_v_z_inv_map * weight_map) * depth.unsqueeze_(
        -1)  # [batch_size, image_size, image_size, 3]
    weight_map = weight_map.unsqueeze_(
        -1)  # [batch_size, image_size, image_size, 3, 1]

    # uv map
    if self.renderer.fill_back:
        faces_vt_idx = torch.cat(
            (self.faces_vt_idx,
             self.faces_vt_idx[:, :, list(reversed(range(self.faces_vt_idx.shape[-1])))]),
            dim=1).detach()
    else:
        faces_vt_idx = self.faces_vt_idx.detach()
    faces_vt = nr.vertex_attrs_to_faces(self.vertices_texcoords,
                                        faces_vt_idx)  # [1, num_face, 3, 2]
    uv_map = faces_vt[:, face_index_map.long()].squeeze_(
        0)  # [batch_size, image_size, image_size, 3, 2], before weighted combination
    uv_map = (uv_map * weight_map).sum(
        -2)  # [batch_size, image_size, image_size, 2], after weighted combination
    uv_map = uv_map - uv_map.floor()  # handle uv wrapping, keep uv in [0, 1]

    # normal map in world space
    if self.renderer.fill_back:
        faces_vn_idx = torch.cat(
            (self.faces_vn_idx,
             self.faces_vn_idx[:, :, list(reversed(range(self.faces_vn_idx.shape[-1])))]),
            dim=1).detach()
    else:
        faces_vn_idx = self.faces_vn_idx.detach()
    faces_vn = nr.vertex_attrs_to_faces(self.vertices_normals,
                                        faces_vn_idx)  # [1, num_face, 3, 3]
    normal_map = faces_vn[:, face_index_map.long()].squeeze_(
        0)  # [batch_size, image_size, image_size, 3, 3], before weighted combination
    normal_map = (normal_map * weight_map).sum(
        -2)  # [batch_size, image_size, image_size, 3], after weighted combination
    normal_map = torch.nn.functional.normalize(normal_map, dim=-1)

    # normal map in camera space (rotation only)
    normal_map_flat = normal_map.flatten(start_dim=1, end_dim=2).permute((0, 2, 1))
    normal_map_cam = pose[:, :3, :3].matmul(normal_map_flat).permute(
        (0, 2, 1)).reshape(normal_map.shape)
    normal_map_cam = torch.nn.functional.normalize(normal_map_cam, dim=-1)

    # position map in world space
    faces_v = nr.vertex_attrs_to_faces(self.vertices,
                                       faces_v_idx)  # [1, num_face, 3, 3]
    position_map = faces_v[
        0, face_index_map.long()]  # [batch_size, image_size, image_size, 3, 3], before weighted combination
    position_map = (position_map * weight_map).sum(
        -2)  # [batch_size, image_size, image_size, 3], after weighted combination

    # position map in camera space (full rigid transform R @ p + t)
    position_map_flat = position_map.flatten(start_dim=1, end_dim=2).permute((0, 2, 1))
    position_map_cam = pose[:, :3, :3].matmul(position_map_flat).permute(
        (0, 2, 1)).reshape(position_map.shape) + pose[:, :3, -1][:, None, None, :]

    return (uv_map, alpha, face_index_map, weight_map, faces_v_idx, normal_map,
            normal_map_cam, faces_v, faces_vt, position_map, position_map_cam,
            depth, v_uvz, v_front_mask)
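# Minimal self-contained sketch of the world-to-camera convention used above:
# normals are rotated by R only, while positions get the full rigid transform
# R @ p + t. Shapes below are illustrative, not the renderer's.
def _world_to_cam_sketch():
    R = torch.eye(3)[None]         # [1, 3, 3] camera rotation
    t = torch.zeros(1, 3)          # [1, 3] camera translation
    pmap = torch.rand(1, 8, 8, 3)  # [N, H, W, 3] world-space positions
    flat = pmap.flatten(1, 2).permute(0, 2, 1)                  # [N, 3, H*W]
    pcam = R.matmul(flat).permute(0, 2, 1).reshape(pmap.shape)  # rotate
    pcam = pcam + t[:, None, None, :]                           # then translate
    return pcam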