def test_tf2_benchmark(self):
  with tf.device('/GPU:0'):
    grid = tf.convert_to_tensor(self.grid)
    guide = tf.convert_to_tensor(self.guide)
    f = lambda: tf2_ops.bilateral_slice(grid, guide).numpy()
    mean_elapsed_ms, elapsed_ms = _timeit(f, self.burn_iterations,
                                          self.benchmark_iterations)
    print(f'TF2 batched bilateral_slice took {mean_elapsed_ms} ms per '
          f'iteration, {elapsed_ms} ms total')
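# The _timeit helper used above is not shown in this excerpt. Below is a
# minimal sketch of one possible implementation; the burn-in/measurement split
# and the (mean_ms, total_ms) return convention are assumptions, not
# necessarily how the repo's helper is written.
import time


def _timeit(f, burn_iterations, benchmark_iterations):
  # Run a few untimed iterations first so kernel compilation and memory
  # allocation do not skew the measurement.
  for _ in range(burn_iterations):
    f()
  start = time.time()
  for _ in range(benchmark_iterations):
    f()
  elapsed_ms = (time.time() - start) * 1000.0
  return elapsed_ms / benchmark_iterations, elapsed_ms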
def run_grad_test(self, batch_size, h, w, input_channels, gh, gw, gd,
                  output_channels, grad_tensor_name, use_gpu):
  dev = _get_device_string(use_gpu)
  gc = (1 + input_channels) * output_channels
  grid_shape = [batch_size, gh, gw, gd, gc]
  guide_shape = [batch_size, h, w]
  output_shape = [batch_size, h, w, gc]
  grid_data = np.random.rand(*grid_shape).astype(np.float32)
  guide_data = np.random.rand(*guide_shape).astype(np.float32)

  tf.reset_default_graph()
  graph = tf.Graph()
  with graph.as_default():
    with tf.device(dev):
      grid_tensor = tf.convert_to_tensor(grid_data, name='grid',
                                         dtype=tf.float32)
      guide_tensor = tf.convert_to_tensor(guide_data, name='guide',
                                          dtype=tf.float32)
      output_tensor = ops.bilateral_slice(grid_tensor, guide_tensor)

      if grad_tensor_name == 'grid':
        grad_tensor = grid_tensor
        grad_shape = grid_shape
      elif grad_tensor_name == 'guide':
        grad_tensor = guide_tensor
        grad_shape = guide_shape

    # It is important to use self.test_session, which disables graph
    # optimization; otherwise the GPU ops will not be used. See details here:
    # https://github.com/tensorflow/tensorflow/issues/2054
    with self.test_session(graph=graph, use_gpu=use_gpu, force_gpu=use_gpu):
      err = tf.test.compute_gradient_error(grad_tensor, grad_shape,
                                           output_tensor, output_shape,
                                           delta=1e-4)
      # Note that the gradient cannot be exact, as trilinear interpolation is
      # not a smooth function: when the interpolated point lies exactly on the
      # grid, the gradient does not exist. Therefore the analytical gradient
      # (from the gradient op implemented in bilateral_slice.cu.cc) and the
      # numerical gradient (from tf.test.compute_gradient_error) will never
      # match exactly.
      self.assertLess(err, 3e-3)
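# A hypothetical example of how run_grad_test might be invoked from a concrete
# test method; the shape values below are illustrative only and are not taken
# from the repo's actual parameterization.
def test_grid_gradient_gpu(self):
  self.run_grad_test(
      batch_size=3,
      h=8,
      w=5,
      input_channels=3,
      gh=6,
      gw=3,
      gd=7,
      output_channels=4,
      grad_tensor_name='grid',
      use_gpu=True)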
def run_bilateral_slice(self, grid_data, guide_data, use_gpu):
  dev = _get_device_string(use_gpu)

  graph = tf.Graph()
  with graph.as_default():
    with tf.device(dev):
      grid_tensor = tf.convert_to_tensor(grid_data, name='grid',
                                         dtype=tf.float32)
      guide_tensor = tf.convert_to_tensor(guide_data, name='guide',
                                          dtype=tf.float32)
      output_tensor = ops.bilateral_slice(grid_tensor, guide_tensor)

    with self.test_session(graph=graph, use_gpu=use_gpu,
                           force_gpu=use_gpu) as sess:
      output_data = sess.run(output_tensor)

  return output_data, output_tensor
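# _get_device_string is referenced by the helpers above but is not defined in
# this excerpt. A plausible sketch, assuming it simply maps the boolean flag to
# a TensorFlow device string:
def _get_device_string(use_gpu):
  return '/gpu:0' if use_gpu else '/cpu:0'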
def bilateral_slice(grid, guide, name=None):
  """Slices into a bilateral grid using the guide map.

  Args:
    grid: (Tensor) [batch_size, grid_h, grid_w, depth, n_outputs] grid to
      slice from.
    guide: (Tensor) [batch_size, h, w] guide map to slice along.
    name: (string) name for the operation.

  Returns:
    sliced: (Tensor) [batch_size, h, w, n_outputs] sliced output.
  """
  with tf.name_scope(name):
    gridshape = grid.get_shape().as_list()
    if len(gridshape) == 6:
      # A 6-D [b, gh, gw, gd, n_out, n_in] grid is flattened into the 5-D
      # [b, gh, gw, gd, n_out * n_in] layout expected by the slicing kernel.
      _, _, _, _, n_out, n_in = gridshape
      grid = tf.concat(tf.unstack(grid, None, axis=5), 4)

    sliced = hdrnet_ops.bilateral_slice(grid, guide)

    if len(gridshape) == 6:
      # Restore the [b, h, w, n_out, n_in] channel structure of the output.
      sliced = tf.stack(tf.split(sliced, n_in, axis=3), axis=4)
    return sliced
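# Illustrative usage of bilateral_slice, sketched as a small example function
# that is not part of the original module; the shapes are arbitrary. For a
# [b, gh, gw, gd, gc] grid and a [b, h, w] guide, the sliced output has shape
# [b, h, w, gc]; with a 6-D [b, gh, gw, gd, n_out, n_in] grid the output is
# [b, h, w, n_out, n_in].
def _example_bilateral_slice_usage():
  grid = tf.random.uniform([1, 16, 16, 8, 12])  # [b, gh, gw, gd, gc]
  guide = tf.random.uniform([1, 256, 256])      # [b, h, w]
  return bilateral_slice(grid, guide)           # -> shape [1, 256, 256, 12]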
def test_bilateral_slice_jax_close_to_tf2(self):
  batch_size = 4
  gh = 16
  gw = 12
  gd = 8
  gc = 2
  h = 640
  w = 480
  grid_shape = (batch_size, gh, gw, gd, gc)
  guide_shape = (batch_size, h, w)
  expected_output_shape = (batch_size, h, w, gc)

  grid = np.random.rand(*grid_shape).astype(np.float32)
  guide = np.random.rand(*guide_shape).astype(np.float32)

  tf2_sliced = tf2_ops.bilateral_slice(grid, guide).numpy()
  jax_sliced = self.jax_batch_slice(grid, guide)

  self.assertTupleEqual(tf2_sliced.shape, expected_output_shape)
  self.assertTupleEqual(jax_sliced.shape, expected_output_shape)
  self.assertAllClose(tf2_sliced, jax_sliced)
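# self.jax_batch_slice is assumed to be a batched wrapper around a per-image
# JAX slicing function, set up elsewhere in the test fixture. One plausible
# way to build such a wrapper is with jax.vmap over the leading batch axis of
# both inputs; the single-image op name below is a hypothetical placeholder,
# not necessarily the repo's actual API:
#
#   self.jax_batch_slice = jax.jit(jax.vmap(jax_ops.bilateral_slice))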