def batch_displacement_warp3d(imgs, vector_fields):
    """Warp a batch of 3-D images with dense displacement vector fields.

    Each voxel is sampled at its identity-grid position plus the
    corresponding displacement vector.

    Parameters
    ----------
    imgs : tf.Tensor
        images to be warped [n_batch, xlen, ylen, zlen, n_channel]
    vector_fields : tf.Tensor
        per-voxel displacements [n_batch, 3, xlen, ylen, zlen]

    Returns
    -------
    output : tf.Tensor
        warped images [n_batch, xlen, ylen, zlen, n_channel]
    """
    img_shape = tf.shape(imgs)
    batch_size, size_x, size_y, size_z = (
        img_shape[0], img_shape[1], img_shape[2], img_shape[3]
    )

    # Identity sampling grid for the whole batch. Its layout is assumed to
    # match `vector_fields` ([n_batch, 3, xlen, ylen, zlen]) -- TODO confirm
    # against batch_mgrid.
    identity_grid = batch_mgrid(batch_size, size_x, size_y, size_z)

    # Sampling locations = identity position shifted by the displacement.
    sample_coords = identity_grid + vector_fields
    return batch_warp3d(imgs, sample_coords)
def batch_affine_warp3d(imgs, theta):
    """Affine-transform a batch of 3-D images.

    Parameters
    ----------
    imgs : tf.Tensor
        images to be warped [n_batch, xlen, ylen, zlen, n_channel]
    theta : tf.Tensor
        parameters of affine transformation [n_batch, 12]. Each row is read
        row-major as a 3x4 matrix [A | t]: the first three columns form the
        linear part A and the last column is the translation t.

    Returns
    -------
    output : tf.Tensor
        warped images [n_batch, xlen, ylen, zlen, n_channel]
    """
    # `tf.batch_matmul` was removed in TensorFlow 1.0, and `tf.matmul` now
    # handles batched operands. Use the old name only where it still exists
    # so the function keeps working on both old and current TensorFlow.
    matmul = getattr(tf, "batch_matmul", tf.matmul)

    n_batch = tf.shape(imgs)[0]
    xlen = tf.shape(imgs)[1]
    ylen = tf.shape(imgs)[2]
    zlen = tf.shape(imgs)[3]

    theta = tf.reshape(theta, [-1, 3, 4])
    # Linear part A: [n_batch, 3, 3]; translation t: [n_batch, 3, 1].
    matrix = tf.slice(theta, [0, 0, 0], [-1, -1, 3])
    t = tf.slice(theta, [0, 0, 3], [-1, -1, -1])

    grids = batch_mgrid(n_batch, xlen, ylen, zlen)
    # Flatten the spatial axes so A is applied to every voxel coordinate in
    # a single product: [n_batch, 3, xlen*ylen*zlen]. Assumes batch_mgrid
    # puts the 3 coordinate components on axis 1 -- TODO confirm.
    grids = tf.reshape(grids, [n_batch, 3, -1])

    # t ([n_batch, 3, 1]) broadcasts across the flattened voxel axis.
    T_g = matmul(matrix, grids) + t
    T_g = tf.reshape(T_g, [n_batch, 3, xlen, ylen, zlen])
    output = batch_warp3d(imgs, T_g)
    return output