Exemplo n.º 1
0
 def testDedupeDenseTensorPerRow(self, values, expected_indices,
                                 expected_output):
   """Checks `mappers.deduplicate_tensor_per_row` on a dense input.

   The mapper output is evaluated in a session, and its `indices` and
   `values` are compared against the expected arrays.

   Args:
     values: Nested Python values used to build the dense input with
       `tf.constant`.
     expected_indices: Expected `indices` of the evaluated mapper output.
     expected_output: Expected `values` of the evaluated mapper output.
   """
   # Build and evaluate inside a private graph. `Tensor.eval()` only works in
   # graph mode, so this keeps the test valid even when eager execution is on.
   # It also keeps this run's ops out of the process-wide default graph.
   with tf.compat.v1.Graph().as_default():
     dense_input = tf.constant(values)
     output_tensor = mappers.deduplicate_tensor_per_row(dense_input)
     with tf.compat.v1.Session():
       output = output_tensor.eval()
       self.assertAllEqual(output.indices, expected_indices)
       self.assertAllEqual(output.values, expected_output)
Exemplo n.º 2
0
 def testDedupeSparseTensorPerRow(self, indices, values, dense_shape,
                                  expected_output_indices,
                                  expected_output_values,
                                  expected_output_shape):
   """Checks `mappers.deduplicate_tensor_per_row` on a `tf.SparseTensor`.

   The mapper output is evaluated in a session, and its `indices`, `values`
   and `dense_shape` are compared against the expected arrays.

   Args:
     indices: `indices` of the input `tf.SparseTensor`.
     values: `values` of the input `tf.SparseTensor`.
     dense_shape: `dense_shape` of the input `tf.SparseTensor`.
     expected_output_indices: Expected `indices` of the evaluated output.
     expected_output_values: Expected `values` of the evaluated output.
     expected_output_shape: Expected `dense_shape` of the evaluated output.
   """
   # Build and evaluate inside a private graph. `Tensor.eval()` only works in
   # graph mode, so this keeps the test valid even when eager execution is on.
   # It also keeps this run's ops out of the process-wide default graph.
   with tf.compat.v1.Graph().as_default():
     sp_input = tf.SparseTensor(
         indices=indices, values=values, dense_shape=dense_shape)
     output_tensor = mappers.deduplicate_tensor_per_row(sp_input)
     with tf.compat.v1.Session():
       output = output_tensor.eval()
       self.assertAllEqual(output.indices, expected_output_indices)
       self.assertAllEqual(output.values, expected_output_values)
       self.assertAllEqual(output.dense_shape, expected_output_shape)
Exemplo n.º 3
0
 def testDedup3dInputRaises(self):
   """Checks that a rank-3 dense input makes the mapper raise ValueError."""
   rank3_values = [
       [[b'a', b'a'], [b'b', b'b']],
       [[b'a', b'a'], [b'd', b'd']],
   ]
   rank3_tensor = tf.constant(rank3_values)
   with self.assertRaises(ValueError):
     mappers.deduplicate_tensor_per_row(rank3_tensor)