def testDedupeDenseTensorPerRow(self, values, expected_indices,
                                expected_output):
    """Checks `mappers.deduplicate_tensor_per_row` on a dense input.

    Args:
      values: Nested Python list turned into the dense input via
        `tf.constant`.
      expected_indices: Expected `indices` of the evaluated result.
      expected_output: Expected `values` of the evaluated result.
    """
    dedupe_result = mappers.deduplicate_tensor_per_row(tf.constant(values))
    # Evaluated in a graph-mode v1 Session. The evaluated result exposes
    # `.indices` and `.values` (presumably a SparseTensorValue; confirm
    # against `mappers`).
    with tf.compat.v1.Session():
        evaluated = dedupe_result.eval()
    self.assertAllEqual(evaluated.indices, expected_indices)
    self.assertAllEqual(evaluated.values, expected_output)
def testDedupeSparseTensorPerRow(self, indices, values, dense_shape,
                                 expected_output_indices,
                                 expected_output_values,
                                 expected_output_shape):
    """Checks `mappers.deduplicate_tensor_per_row` on a sparse input.

    Args:
      indices: `indices` of the input `tf.SparseTensor`.
      values: `values` of the input `tf.SparseTensor`.
      dense_shape: `dense_shape` of the input `tf.SparseTensor`.
      expected_output_indices: Expected `indices` of the evaluated result.
      expected_output_values: Expected `values` of the evaluated result.
      expected_output_shape: Expected `dense_shape` of the evaluated result.
    """
    sparse_input = tf.SparseTensor(
        indices=indices, values=values, dense_shape=dense_shape)
    dedupe_result = mappers.deduplicate_tensor_per_row(sparse_input)
    with tf.compat.v1.Session():
        evaluated = dedupe_result.eval()
    # Compare each component in a fixed order: indices, values, dense_shape.
    checks = (
        (evaluated.indices, expected_output_indices),
        (evaluated.values, expected_output_values),
        (evaluated.dense_shape, expected_output_shape),
    )
    for actual, expected in checks:
        self.assertAllEqual(actual, expected)
def testDedup3dInputRaises(self):
    """Checks that a rank-3 input makes deduplication raise ValueError."""
    # Shape (2, 2, 2): two batches of two rows of two byte strings each.
    rank3_input = tf.constant([
        [[b'a', b'a'], [b'b', b'b']],
        [[b'a', b'a'], [b'd', b'd']],
    ])
    with self.assertRaises(ValueError):
        mappers.deduplicate_tensor_per_row(rank3_input)