def graph_fn():
  """Builds the heatmap for two point instances on a 3x5 grid.

  Instance 0 is centred at (y=1.5, x=2.5) with sigma 0.1 and is assigned to
  channel 0; instance 1 is centred at (y=0.5, x=4.5) with sigma 0.5 and is
  assigned to channel 1. Both instances carry weight 1.

  Returns:
    Whatever ``ta_utils.coordinates_to_heatmap`` produces for these inputs
    (presumably a [height, width, num_channels] = [3, 5, 3] tensor; confirm
    against ta_utils).
  """
  grid_y, grid_x = ta_utils.image_shape_to_grids(height=3, width=5)
  centers_y = tf.constant([1.5, 0.5], dtype=tf.float32)
  centers_x = tf.constant([2.5, 4.5], dtype=tf.float32)
  sigmas = tf.constant([0.1, 0.5], dtype=tf.float32)
  # One row per instance, one column per channel; no instance is assigned
  # to the third channel.
  onehot = tf.constant([[1, 0, 0], [0, 1, 0]], dtype=tf.float32)
  weights = tf.constant([1, 1], dtype=tf.float32)
  return ta_utils.coordinates_to_heatmap(
      grid_y, grid_x, centers_y, centers_x, sigmas, onehot, weights)
Example #2
0
 def test_coordinates_to_heatmap(self):
   """Checks that each instance produces a unit peak in its own channel.

   Two instances are placed on a 3x5 grid: (y=1.5, x=2.5) with sigma 0.1 in
   channel 0, and (y=0.5, x=4.5) with sigma 0.5 in channel 1. The heatmap
   must equal 1.0 at cell (1, 2) of channel 0 and at cell (0, 4) of
   channel 1.
   """
   grid_y, grid_x = ta_utils.image_shape_to_grids(height=3, width=5)
   centers_y = tf.constant([1.5, 0.5], dtype=tf.float32)
   centers_x = tf.constant([2.5, 4.5], dtype=tf.float32)
   sigmas = tf.constant([0.1, 0.5], dtype=tf.float32)
   # Row i picks the channel for instance i; the third channel gets nothing.
   onehot = tf.constant([[1, 0, 0], [0, 1, 0]], dtype=tf.float32)
   weights = tf.constant([1, 1], dtype=tf.float32)
   heatmap_np = ta_utils.coordinates_to_heatmap(
       grid_y, grid_x, centers_y, centers_x, sigmas, onehot,
       weights).numpy()
   # Expected (row, col, channel) of each instance's peak: first class at
   # (1, 2), second class at (0, 4). These are the floors of the fractional
   # centres -- presumably coordinates_to_heatmap snaps centres to the grid
   # that way; confirm in ta_utils.
   expected_peaks = ((1, 2, 0), (0, 4, 1))
   for row, col, channel in expected_peaks:
     self.assertAlmostEqual(1.0, heatmap_np[row, col, channel])