def project_vertex_render(
    ori_img,
    norm_image,
    clip_xyzw,
    tri,
    imageH,
    imageW,
    ver_rgb,
    ver_mask,
    para_illum,
    var_scope_name,
):
    """Rasterize per-vertex colors, apply SH shading, and composite the result onto the input image."""
    with tf.variable_scope(var_scope_name):
        batch_size, _, _ = clip_xyzw.get_shape().as_list()

        # Rasterize vertex colors and the per-vertex mask in a single pass.
        aug_ver_attrs = tf.concat([ver_rgb, ver_mask], axis=2)
        attrs, _ = rasterize_clip_space(
            clip_xyzw, aug_ver_attrs, tri, imageW, imageH, -1.0
        )

        # Apply shading
        diffuse_image = tf.reshape(
            attrs[:, :, :, :3], [batch_size, imageH, imageW, 3]
        )
        alphas = tf.reshape(attrs[:, :, :, 3:], [batch_size, imageH, imageW, 1])

        rgb_images, shading_image = Shader.sh_shader(
            norm_image, alphas, ori_img, para_illum, diffuse_image
        )
        ori_img_remove_shading = ori_img / shading_image

        diffuse_image = tf.clip_by_value(diffuse_image, 0, 1)
        rgb_images = tf.clip_by_value(rgb_images, 0, 1)
        attrs_image = tf.clip_by_value(alphas, 0, 1)
        ori_img_remove_shading = tf.clip_by_value(ori_img_remove_shading, 0, 1)

        # Blend the shaded render with the original image using the rasterized mask.
        render_image = rgb_images * attrs_image + ori_img * (1 - attrs_image)

    return render_image, attrs_image, ori_img_remove_shading

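# Illustrative sketch (not part of the pipeline): the compositing and "de-shading"
# steps above amount to a per-pixel blend by the rasterized mask and an element-wise
# division by the shading image. The toy tensors below are assumptions made only for
# this example; Shader.sh_shader is not called here.
def _example_composite_and_remove_shading():
    shaded = tf.constant(0.6, shape=[1, 4, 4, 3])      # stand-in for rgb_images
    background = tf.constant(0.2, shape=[1, 4, 4, 3])  # stand-in for ori_img
    shading = tf.constant(0.8, shape=[1, 4, 4, 3])     # stand-in for shading_image
    mask = tf.constant(1.0, shape=[1, 4, 4, 1])        # stand-in for attrs_image
    render = shaded * mask + background * (1 - mask)
    albedo_like = tf.clip_by_value(background / shading, 0, 1)
    return render, albedo_like
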
def generate_proj_information(
    ver_xyz,
    trans_Mat,
    K_img,
    imageH,
    imageW,
    tri,
    project_type="Pers",
    name="ver_norm_and_ver_depth",
):
    """Transform vertices into camera space, project them to the image plane, and
    rasterize per-vertex normals, depth, and the contour mask."""
    # Homogeneous vertex coordinates: [batch, N, 3] -> [batch, N, 4]
    ver_w = tf.ones_like(ver_xyz[:, :, 0:1], name="ver_w")
    ver_xyzw = tf.concat([ver_xyz, ver_w], axis=2)  # batch x N x 4

    # Rigid transform into camera space, then apply the intrinsic matrix.
    vertex_img = tf.matmul(ver_xyzw, trans_Mat)  # batch x N x 4
    cam_xyz = vertex_img[:, :, 0:3]  # batch x N x 3
    K_img = tf.transpose(K_img, [0, 2, 1])  # batch x 3 x 3
    proj_xyz_batch = tf.matmul(cam_xyz, K_img)  # batch x N x 3
    proj_xyz_depth_batch = tf.matmul(cam_xyz, K_img)  # batch x N x 3

    # Normalized device coordinates in [-1, 1].
    if project_type == "Orth":
        clip_x = tf.expand_dims(
            (proj_xyz_batch[:, :, 0] + imageW / 2) / imageW * 2 - 1, axis=2
        )  # batch x N x 1
        clip_y = tf.expand_dims(
            (proj_xyz_batch[:, :, 1] + imageH / 2) / imageH * 2 - 1, axis=2
        )  # batch x N x 1
    else:
        clip_x = tf.expand_dims(
            (proj_xyz_batch[:, :, 0] / proj_xyz_batch[:, :, 2]) / imageW * 2 - 1,
            axis=2,
        )  # batch x N x 1
        clip_y = tf.expand_dims(
            (proj_xyz_batch[:, :, 1] / proj_xyz_batch[:, :, 2]) / imageH * 2 - 1,
            axis=2,
        )  # batch x N x 1
    clip_z = tf.expand_dims(
        tf.nn.l2_normalize(proj_xyz_batch[:, :, 2], dim=1, epsilon=1e-10), axis=2
    )
    clip_xyz = tf.concat([clip_x, clip_y, clip_z], axis=2)  # batch x N x 3
    clip_w = tf.ones_like(clip_xyz[:, :, 0:1], name="clip_w")
    clip_xyzw = tf.concat([clip_xyz, clip_w], axis=2)  # batch x N x 4

    # Pixel-space projection of every vertex.
    if project_type == "Orth":
        proj_x = tf.expand_dims(
            proj_xyz_batch[:, :, 0] + imageW / 2, axis=2
        )  # batch x N x 1
        proj_y = tf.expand_dims(proj_xyz_batch[:, :, 1] + imageH / 2, axis=2)
    else:
        proj_x = tf.expand_dims(
            proj_xyz_batch[:, :, 0] / proj_xyz_batch[:, :, 2], axis=2
        )  # batch x N x 1
        proj_y = tf.expand_dims(
            proj_xyz_batch[:, :, 1] / proj_xyz_batch[:, :, 2], axis=2
        )  # batch x N x 1
    proj_z = tf.expand_dims(proj_xyz_batch[:, :, 2], axis=2)
    proj_xy = tf.concat([proj_x, proj_y], axis=2)  # batch x N x 2
    depth_infor = tf.expand_dims(
        proj_xyz_depth_batch[:, :, 2], axis=2
    )  # batch x N x 1

    with tf.variable_scope(name):
        # Per-vertex normals and contour mask, rasterized together with depth.
        ver_norm, ver_contour_mask = Projector.get_ver_norm(cam_xyz, tri)
        norm_depth_infor = tf.concat(
            [ver_norm, depth_infor, ver_contour_mask], axis=2
        )  # batch x N x 5 (normal + depth + contour mask)
        norm_depth_image, alphas = rasterize_clip_space(
            clip_xyzw, norm_depth_infor, tri, imageW, imageH, 0.0
        )
        norm_image = norm_depth_image[:, :, :, 0:3]
        depth_image = tf.expand_dims(norm_depth_image[:, :, :, 3], 3)
        ver_contour_mask_image = tf.expand_dims(norm_depth_image[:, :, :, 4], 3)

    return (
        norm_image,
        ver_norm,
        alphas,
        clip_xyzw,
        proj_xy,
        proj_z,
        depth_image,
        ver_contour_mask,
        ver_contour_mask_image,
    )

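# Illustrative sketch (not part of the pipeline): for the perspective branch above,
# a camera-space point (x, y, z) is mapped to pixel coordinates via the intrinsic
# matrix and then to normalized device coordinates in [-1, 1]. The focal length,
# principal point, image size, and sample point below are assumptions made only
# for this example.
def _example_perspective_to_ndc():
    imageW, imageH = 300, 300
    K = tf.constant([[250.0, 0.0, 150.0],
                     [0.0, 250.0, 150.0],
                     [0.0, 0.0, 1.0]])                    # assumed intrinsics
    cam_xyz = tf.constant([[10.0, -20.0, 500.0]])         # N x 3 camera-space points
    proj = tf.matmul(cam_xyz, K, transpose_b=True)        # N x 3 pixel coordinates (x, y, z)
    ndc_x = (proj[:, 0] / proj[:, 2]) / imageW * 2 - 1    # perspective divide, then map to [-1, 1]
    ndc_y = (proj[:, 1] / proj[:, 2]) / imageH * 2 - 1
    return ndc_x, ndc_y
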
def project_uv_render(
    ori_img,
    norm_image,
    clip_xyzw,
    tri,
    tri_vt,
    vt_list,
    imageH,
    imageW,
    uv_rgb,
    uv_mask,
    para_illum,
    var_scope_name,
):
    """Render the face by sampling albedo from a UV texture, shading it with SH
    illumination, and compositing the result onto the original image."""
    batch_size, _, _ = clip_xyzw.get_shape().as_list()

    # Get UV coordinates in texture pixel space.
    V, U = tf.split(vt_list, 2, axis=1)
    uv_size = uv_rgb.get_shape().as_list()[1]
    U = (1.0 - U) * uv_size
    V = V * uv_size
    UV = tf.concat([U, V], axis=1)
    batch_UV = tf.tile(UV, [batch_size, 1])

    # Get clip_xyzw for each UV vertex (according to the correspondence between
    # tri and tri_vt) via gather and scatter.
    EPS = 1e-12
    batch_tri_indices = tf.reshape(
        tf.tile(tf.expand_dims(tf.range(batch_size), axis=1), [1, len(tri_vt) * 3]),
        [-1],
        name="batch_tri_indices",
    )
    tri_inds = tf.stack(
        [
            batch_tri_indices,
            tf.concat([tf.reshape(tri, [len(tri) * 3])] * batch_size, axis=0),
        ],
        axis=1,
    )
    tri_vt_inds = tf.stack(
        [
            batch_tri_indices,
            tf.concat([tf.reshape(tri_vt, [len(tri_vt) * 3])] * batch_size, axis=0),
        ],
        axis=1,
    )
    tri_clip_xyzw = tf.gather_nd(clip_xyzw, tri_inds, name="tri_clip_xyzw")

    # Accumulators for the scatter-mean of clip coordinates over UV vertices.
    ver_uv_clip_xyzw_sum = tf.get_variable(
        shape=[batch_size, len(vt_list), 4],
        dtype=tf.float32,
        initializer=tf.zeros_initializer(),
        name=var_scope_name + "ver_uv_clip_xyzw_sum",
        trainable=False,
    )
    ver_uv_clip_xyzw_cnt = tf.get_variable(
        shape=[batch_size, len(vt_list), 4],
        dtype=tf.float32,
        initializer=tf.zeros_initializer(),
        name=var_scope_name + "ver_uv_clip_xyzw_cnt",
        trainable=False,
    )
    init_ver_uv = tf.zeros(shape=[batch_size, len(vt_list), 4], dtype=tf.float32)
    assign_op1 = tf.assign(ver_uv_clip_xyzw_sum, init_ver_uv)
    assign_op2 = tf.assign(ver_uv_clip_xyzw_cnt, init_ver_uv)
    with tf.control_dependencies([assign_op1, assign_op2]):
        ver_uv_clip_xyzw_sum = tf.scatter_nd_add(
            ver_uv_clip_xyzw_sum, tri_vt_inds, tri_clip_xyzw
        )
        ver_uv_clip_xyzw_cnt = tf.scatter_nd_add(
            ver_uv_clip_xyzw_cnt, tri_vt_inds, tf.ones_like(tri_clip_xyzw)
        )
        ver_uv_clip_xyzw = tf.div(ver_uv_clip_xyzw_sum, ver_uv_clip_xyzw_cnt + EPS)

    # Rasterize per-pixel UV coordinates, then look up the albedo texture.
    uv_image, uv_alphas = rasterize_clip_space(
        ver_uv_clip_xyzw, batch_UV, tri_vt, imageW, imageH, -1.0
    )
    uv_image = tf.clip_by_value(
        tf.cast(uv_image, tf.int32), 0, uv_size - 1
    )  # texel indices must be integers inside the texture
    batch_vt_indices = tf.reshape(
        tf.tile(tf.expand_dims(tf.range(batch_size), axis=1), [1, imageW * imageH]),
        [-1, 1],
        name="batch_indices",
    )
    batch_vt_indices = tf.concat(
        [batch_vt_indices, tf.reshape(uv_image, [-1, 2])], axis=1
    )  # [batch, row, col] indices into the UV texture
    diffuse_image = tf.reshape(
        tf.gather_nd(uv_rgb, batch_vt_indices), [batch_size, imageH, imageW, 3]
    )
    uv_alphas = (
        tf.reshape(
            tf.gather_nd(uv_mask[:, :, :, 0], batch_vt_indices),
            [batch_size, imageH, imageW, 1],
        )
        * uv_alphas
    )

    # Apply shading
    para_light = para_illum
    background = ori_img
    rgb_images, shading_image = Shader.sh_shader(
        norm_image, uv_alphas, background, para_light, diffuse_image
    )
    ori_img_remove_shading = ori_img / shading_image

    diffuse_image = tf.clip_by_value(diffuse_image, 0, 1)
    rgb_images = tf.clip_by_value(rgb_images, 0, 1)
    uv_attrs_image = tf.clip_by_value(uv_alphas, 0, 1)
    ori_img_remove_shading = tf.clip_by_value(ori_img_remove_shading, 0, 1)

    # Blend the shaded render with the original image using the UV mask.
    render_image = rgb_images * uv_attrs_image + ori_img * (1 - uv_attrs_image)

    return render_image, uv_attrs_image, ori_img_remove_shading

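# Illustrative sketch (not part of the pipeline): the gather/scatter block above
# computes, for every UV vertex, the mean of the clip coordinates of all triangle
# corners mapped to it. The stateless variant below uses tf.scatter_nd instead of
# the variable-based tf.scatter_nd_add accumulators; the toy indices and values
# are assumptions made only for this example.
def _example_scatter_mean():
    n_uv_vertices = 3
    # Two triangle corners map to UV vertex 0, one corner maps to UV vertex 2.
    indices = tf.constant([[0], [0], [2]], dtype=tf.int32)
    values = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    summed = tf.scatter_nd(indices, values, shape=[n_uv_vertices, 2])
    counts = tf.scatter_nd(indices, tf.ones_like(values), shape=[n_uv_vertices, 2])
    return summed / (counts + 1e-12)  # UV vertices with no corners stay zero
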
def warp_ver_to_uv(
    v_attrs,  # tensor [batch_size, N, 3], float32
    tri_v,  # tensor [triangles, 3], int32
    tri_vt,  # tensor [triangles, 3], int32
    vt_list,  # tensor [vertex_texture_count, 2], float32
    uv_size,  # int
):
    """Scatter per-vertex attributes onto UV vertices (averaging shared corners)
    and rasterize them into a uv_size x uv_size UV-space map."""
    if len(v_attrs.shape) != 3:
        raise ValueError("v_attrs must have shape [batch_size, vertex_count, ?].")
    if len(tri_v.shape) != 2:
        raise ValueError("tri_v must have shape [triangles, 3].")
    if len(tri_vt.shape) != 2:
        raise ValueError("tri_vt must have shape [triangles, 3].")
    if len(vt_list.shape) != 2:
        raise ValueError("vt_list must have shape [vertex_texture_count, 2].")
    if tri_vt.dtype != tf.int32:
        raise ValueError("tri_vt must be of type int32.")
    if tri_v.dtype != tf.int32:
        raise ValueError("tri_v must be of type int32.")
    if v_attrs.dtype != tf.float32:
        raise ValueError("v_attrs must be of type float32.")
    if vt_list.dtype != tf.float32:
        raise ValueError("vt_list must be of type float32.")

    # Add sample (batch) indices to tri_v and tri_vt.
    batch_size, v_cnt, n_channels = v_attrs.shape
    n_tri = tri_v.shape[0]
    sample_indices = tf.reshape(
        tf.tile(tf.expand_dims(tf.range(batch_size), axis=1), [1, n_tri * 3]),
        [-1],
        name="sample_indices",
    )
    tri_v_list = tf.concat(
        [tf.reshape(tri_v, [-1])] * batch_size, axis=0, name="tri_v_list"
    )
    tri_vt_list = tf.concat(
        [tf.reshape(tri_vt, [-1])] * batch_size, axis=0, name="tri_vt_list"
    )
    tri_v_list = tf.stack(
        [sample_indices, tri_v_list], axis=1, name="sample_tri_v_list"
    )
    tri_vt_list = tf.stack(
        [sample_indices, tri_vt_list], axis=1, name="sample_tri_vt_list"
    )

    # Gather vertex attributes for every triangle corner.
    v_attrs_list = tf.gather_nd(v_attrs, tri_v_list)
    v_attrs_count = tf.ones(dtype=tf.float32, shape=[batch_size * n_tri * 3, 1])
    assert (
        len(v_attrs_list.shape) == 2
        and v_attrs_list.shape[0].value == tri_v_list.shape[0].value
    )

    # Scatter onto UV vertices and average attributes shared by several corners.
    n_vt = vt_list.shape[0].value
    vt_attrs_list = tf.scatter_nd(
        tri_vt_list,
        v_attrs_list,
        shape=[batch_size, n_vt, v_attrs.shape[2].value],
        name="vt_attrs_list",
    )
    vt_attrs_count = tf.scatter_nd(
        tri_vt_list, v_attrs_count, shape=[batch_size, n_vt, 1], name="vt_attrs_count"
    )
    vt_attrs_list = tf.div(vt_attrs_list, vt_attrs_count)

    # Map UV coordinates in [0, 1] to clip space in [-1, 1], flipping the V axis
    # and adding a tiny random z to break depth ties.
    assert len(vt_list.shape) == 2 and vt_list.shape[1].value == 2
    u, v = tf.split(vt_list, 2, axis=1)
    z = tf.random_normal(shape=[n_vt, 1], stddev=0.000001)
    vt_list = tf.concat([(u * 2 - 1), ((1 - v) * 2 - 1), z], axis=1, name="full_vt")
    vt_list = tf.stack([vt_list] * batch_size, axis=0)

    # Rasterize the scattered attributes into UV space.
    renders, _ = rasterize.rasterize_clip_space(
        vt_list,
        vt_attrs_list,
        tri_vt,
        uv_size,
        uv_size,
        [-1] * vt_attrs_list.shape[2].value,
    )
    renders.set_shape((batch_size, uv_size, uv_size, n_channels))
    # renders = tf.clip_by_value(renders, 0, 1)
    # renders = renders[0]
    return renders

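# Illustrative sketch (not part of the pipeline): warp_ver_to_uv rasterizes in UV
# space by treating each UV coordinate in [0, 1] as a clip-space position in
# [-1, 1], with the V axis flipped and a tiny random z added, exactly as in the
# function above. The toy vt coordinates below are assumptions made only for this
# example.
def _example_uv_to_clip():
    vt = tf.constant([[0.25, 0.75], [0.5, 0.5]])       # UV coordinates in [0, 1]
    u, v = tf.split(vt, 2, axis=1)
    z = tf.random_normal(shape=[2, 1], stddev=1e-6)    # tiny jitter, as above
    return tf.concat([u * 2 - 1, (1 - v) * 2 - 1, z], axis=1)  # clip-space xyz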