def mpi_resample_cube(mpi, tgt, intrinsics, depth_planes, side_length,
                      cube_res):
  """Resample MPI onto cube centered at target point.

  The cube is an axis-aligned grid of cube_res^3 samples in the MPI reference
  frame, centered at `tgt`. Each sample is projected into the reference camera
  (via cam2pixel) to obtain pixel coordinates. Its depth is converted to a
  fractional MPI plane index, interpolated linearly in disparity (1 / depth)
  between depth_planes[0] and depth_planes[-1]. The MPI is then trilinearly
  sampled at those locations. The colors are premultiplied by alpha before
  interpolation and divided by it afterwards.

  Args:
    mpi: [B,H,W,D,C], input MPI; the last channel is treated as alpha.
    tgt: [B,3], [x,y,z] coordinates for cube center (in reference/mpi frame).
    intrinsics: [B,3,3], MPI reference camera intrinsics.
    depth_planes: [D] depth values for MPI planes.
    side_length: metric side length of cube.
    cube_res: resolution of each cube dimension.

  Returns:
    resampled: [B, cube_res, cube_res, cube_res, C], resampled MPI with
      un-premultiplied color channels and alpha clipped to [0, 1].
    coords_proj: [B, cube_res, cube_res, cube_res, 4] sampling coordinates
      passed to the gather: (batch index, the two pixel coordinates produced
      by cam2pixel, fractional MPI depth index). Grid axes are ordered
      (y, x, z), following the meshgrid below.
  """
  batch_size = tf.shape(mpi)[0]
  num_depths = tf.shape(mpi)[3]

  # Cube sample positions in the reference frame. meshgrid with 'ij' indexing
  # orders the grid axes as (batch, y, x, z); z runs from +side_length/2 down
  # to -side_length/2.
  b_vals = tf.to_float(tf.range(batch_size))
  x_vals = tf.linspace(-side_length / 2.0, side_length / 2.0, cube_res)
  y_vals = tf.linspace(-side_length / 2.0, side_length / 2.0, cube_res)
  z_vals = tf.linspace(side_length / 2.0, -side_length / 2.0, cube_res)
  b, y, x, z = tf.meshgrid(b_vals, y_vals, x_vals, z_vals, indexing='ij')
  x = x + tgt[:, 0, tf.newaxis, tf.newaxis, tf.newaxis]
  y = y + tgt[:, 1, tf.newaxis, tf.newaxis, tf.newaxis]
  z = z + tgt[:, 2, tf.newaxis, tf.newaxis, tf.newaxis]

  # Homogeneous coords [B, 4, Y, X, Z] -> [B, Z, 4, Y, X] -> [B*Z, 4, Y, X]:
  # every z-slice of every batch element is projected as one "image".
  # After the reshape, row i belongs to batch element i // cube_res.
  ones = tf.ones_like(x)
  coords = tf.stack([x, y, z, ones], axis=1)
  coords_r = tf.reshape(
      tf.transpose(coords, [0, 4, 1, 2, 3]),
      [batch_size * cube_res, 4, cube_res, cube_res])

  # Samples with negative z (behind the reference camera) are flagged and
  # handed to the gather. NOTE(review): presumably trilerp_gather masks
  # these out -- confirm in sampling.trilerp_gather.
  bad_inds = tf.less(z, 0.0)

  # Per-slice 4x4 intrinsics [[K, 0], [0, 0, 0, 1]]. K must be repeated
  # batch-major (b0 x cube_res, b1 x cube_res, ...) to line up with coords_r.
  # A plain tf.tile(intrinsics, [cube_res, 1, 1]) interleaves the batch
  # (b0, b1, ..., b0, b1, ...), which pairs slices with the wrong camera
  # whenever B > 1.
  intrinsics_tile = tf.reshape(
      tf.tile(intrinsics[:, tf.newaxis, :, :], [1, cube_res, 1, 1]),
      [batch_size * cube_res, 3, 3])
  filler = tf.constant([0.0, 0.0, 0.0, 1.0], shape=[1, 1, 4])
  filler = tf.tile(filler, [batch_size * cube_res, 1, 1])
  intrinsics_tile_4 = tf.concat(
      [intrinsics_tile, tf.zeros([batch_size * cube_res, 3, 1])], axis=2)
  intrinsics_tile_4 = tf.concat([intrinsics_tile_4, filler], axis=1)
  # Assumed to yield [B*Z, Y, X, 2] (implied by the reshape to 3 channels
  # below) -- confirm against cam2pixel.
  coords_proj = cam2pixel(coords_r, intrinsics_tile_4)

  # Fractional MPI plane index: linear in disparity, 0 at depth_planes[0] and
  # num_depths - 1 at depth_planes[-1].
  coords_depths = tf.transpose(coords_r[:, 2:3, :, :], [0, 2, 3, 1])
  coords_depth_inds = (tf.to_float(num_depths) - 1) * (
      (1.0 / coords_depths) - (1.0 / depth_planes[0])) / (
          (1.0 / depth_planes[-1]) - (1.0 / depth_planes[0]))
  coords_proj = tf.concat([coords_proj, coords_depth_inds], axis=3)

  # Undo the slice flattening: [B*Z, Y, X, 3] -> [B, Z, Y, X, 3]
  # -> [B, Y, X, Z, 3], then prepend the batch index expected by the gather.
  coords_proj = tf.transpose(
      tf.reshape(coords_proj, [batch_size, cube_res, cube_res, cube_res, 3]),
      [0, 2, 3, 1, 4])
  coords_proj = tf.concat([b[:, :, :, :, tf.newaxis], coords_proj], axis=4)

  # Trilinear interpolation gather from the MPI: interpolate premultiplied
  # RGBAs, then un-premultiply (epsilon guards against zero alpha).
  mpi_alpha = mpi[Ellipsis, -1:]
  mpi_channels_p = mpi[Ellipsis, :-1] * mpi_alpha
  mpi_p = tf.concat([mpi_channels_p, mpi_alpha], axis=-1)
  resampled_p = sampling.trilerp_gather(mpi_p, coords_proj, bad_inds)
  resampled_alpha = tf.clip_by_value(resampled_p[Ellipsis, -1:], 0.0, 1.0)
  resampled_channels = resampled_p[Ellipsis, :-1] / (resampled_alpha + 1e-8)
  resampled = tf.concat([resampled_channels, resampled_alpha], axis=-1)

  return resampled, coords_proj
def spherical_cubevol_resample(vol, env2ref, cube_center, side_length, n_phi,
                               n_theta, n_r):
  """Resample cube volume onto spherical coordinates centered at target point.

  Spherical samples (phi, theta, r) are generated in the env frame and mapped
  into the ref frame with env2ref. They are then converted to fractional
  voxel indices of `vol` and trilinearly sampled. The colors are premultiplied
  by alpha before interpolation and divided by it afterwards.

  Args:
    vol: [B,H,W,D,C], input volume; assumed cubic (H is used as the
      resolution of every axis). The last channel is treated as alpha.
    env2ref: [B,4,4], relative pose transformation (transform env to ref).
    cube_center: [B,3], [x,y,z] coordinates for center of cube volume.
    side_length: side length of cube.
    n_phi: number of samples along vertical spherical coordinate dim.
    n_theta: number of samples along horizontal spherical coordinate dim.
    n_r: number of samples along radius spherical coordinate dim.

  Returns:
    resampled: [B, n_phi, n_theta, n_r, C], with un-premultiplied color
      channels.
    r_vals: [n_r] radii of the spherical shells. They run from the largest
      to the smallest distance between the env origin (env2ref translation)
      and any cube voxel center. The extremes are taken over the whole batch,
      so all batch elements share the same radii.
  """
  batch_size = tf.shape(vol)[0]
  height = tf.shape(vol)[1]
  cube_res = tf.to_float(height)
  half_side = side_length / 2.0
  # cube_center broadcast to [B, 1, 1, 1, 3] for use against grid tensors.
  center = cube_center[:, tf.newaxis, tf.newaxis, tf.newaxis, :]

  # Voxel-center positions of the cube, used only to bound the radii.
  x_vals = tf.linspace(-half_side, half_side, height)
  y_vals = tf.linspace(-half_side, half_side, height)
  z_vals = tf.linspace(half_side, -half_side, height)
  y_c, x_c, z_c = tf.meshgrid(y_vals, x_vals, z_vals, indexing='ij')
  cube_coords = tf.stack(
      [x_c + center[Ellipsis, 0],
       y_c + center[Ellipsis, 1],
       z_c + center[Ellipsis, 2]],
      axis=4)

  # Distances from every voxel center to the env origin, computed once and
  # reduced to a single (batch-wide) min / max.
  env_origin = env2ref[:, :3, 3][:, tf.newaxis, tf.newaxis, tf.newaxis, :]
  voxel_dists = tf.norm(cube_coords - env_origin, axis=4)
  min_r = tf.reduce_min(voxel_dists, axis=[0, 1, 2, 3])
  max_r = tf.reduce_max(voxel_dists, axis=[0, 1, 2, 3])

  # Spherical sample grid (radii ordered far -> near).
  b_vals = tf.to_float(tf.range(batch_size))
  phi_vals = tf.linspace(0.0, np.pi, n_phi)
  theta_vals = tf.linspace(1.5 * np.pi, -0.5 * np.pi, n_theta)
  r_vals = tf.linspace(max_r, min_r, n_r)
  b, phi, theta, r = tf.meshgrid(
      b_vals, phi_vals, theta_vals, r_vals, indexing='ij')

  # Spherical -> cartesian in the env frame (z points forwards).
  x = r * tf.cos(theta) * tf.sin(phi)
  z = r * tf.sin(theta) * tf.sin(phi)
  y = r * tf.cos(phi)

  # Env frame -> ref frame using homogeneous column vectors [..., 4, 1].
  # NOTE(review): tfmm presumably broadcasts env2ref [B,4,4] over the
  # (phi, theta, r) grid dims -- confirm.
  sphere_coords = tf.stack([x, y, z, tf.ones_like(x)],
                           axis=-1)[Ellipsis, tf.newaxis]
  sphere_coords_ref = tfmm(env2ref, sphere_coords)
  x = sphere_coords_ref[Ellipsis, 0, 0]
  y = sphere_coords_ref[Ellipsis, 1, 0]
  z = sphere_coords_ref[Ellipsis, 2, 0]

  # Ref-frame coordinates -> fractional voxel indices. Index 0 is at
  # center - side_length/2 along x, and at center + side_length/2 along y
  # and z. NOTE(review): mpi_resample_cube in this file builds its y axis
  # ascending from -side_length/2, so a volume produced by it would be
  # sampled mirrored in y here -- confirm this flip is intended.
  scale = (cube_res - 1) / side_length
  x_inds = (x - center[Ellipsis, 0] + half_side) * scale
  y_inds = -(y - center[Ellipsis, 1] - half_side) * scale
  z_inds = -(z - center[Ellipsis, 2] - half_side) * scale
  sphere_coords_inds = tf.stack([b, x_inds, y_inds, z_inds], axis=-1)

  # Trilinear interpolation gather from the volume: interpolate premultiplied
  # RGBAs, then un-premultiply (epsilon guards against zero alpha).
  vol_alpha = tf.clip_by_value(vol[Ellipsis, -1:], 0.0, 1.0)
  vol_channels_p = vol[Ellipsis, :-1] * vol_alpha
  vol_p = tf.concat([vol_channels_p, vol_alpha], axis=-1)
  resampled_p = sampling.trilerp_gather(vol_p, sphere_coords_inds)
  resampled_alpha = resampled_p[Ellipsis, -1:]
  resampled_channels = resampled_p[Ellipsis, :-1] / (resampled_alpha + 1e-8)
  resampled = tf.concat([resampled_channels, resampled_alpha], axis=-1)

  return resampled, r_vals