示例#1
0
 def test_scatter_to_2d_not_sorted(self):
   """Check `utils.scatter_to_2d` when segment ids are not in ascending order.

   The op is built and evaluated in a fresh graph, and the evaluated result
   is compared with the expected rows, in which unfilled slots hold the pad
   value (-1).
   """
   # Segment id 0 reappears after id 1, so the ids are unsorted. The values
   # are identical to their segment ids.
   segment_ids = [0, 0, 1, 0, 2, 2]
   values = [0, 0, 1, 0, 2, 2]
   padding = -1
   expected_rows = [
       [0, 0, 0],
       [1, -1, -1],
       [2, 2, -1],
   ]
   with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
     scattered = utils.scatter_to_2d(values, segment_ids, padding)
     result = sess.run(scattered)
     self.assertAllEqual(result, expected_rows)
示例#2
0
 def test_scatter_to_2d_with_smaller_output_shape(self):
     """Check `utils.scatter_to_2d` with an explicit [2, 2] output shape.

     The output shape is passed as the fourth positional argument. The
     expected result has two rows of two columns, with -1 (the pad value)
     in the one slot that holds no input value.
     """
     segment_ids = [0, 0, 0, 1, 2, 2]
     values = [1, 0, 0, 1, 2, 2]
     padding = -1
     target_shape = [2, 2]
     expected_rows = [[1, 0], [1, -1]]
     with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
         scattered = utils.scatter_to_2d(
             values, segment_ids, padding, target_shape)
         result = sess.run(scattered)
         self.assertAllEqual(result, expected_rows)