def testRaggedCountSparseOutputNegativeValue(self):
    """Negative entries in `values` must be rejected by the op.

    The -2 at position 4 is the only negative input and should trigger an
    InvalidArgumentError. No weights are supplied for this case.
    """
    row_splits = [0, 4, 7]
    # All entries are non-negative except the -2.
    row_values = [1, 1, 2, 1, -2, 10, 5]
    expected_message = "Input values must all be non-negative"
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_message):
        # Build and evaluate inside the context: in eager mode the op may
        # raise as soon as it is constructed.
        sparse_counts = gen_count_ops.RaggedCountSparseOutput(
            splits=row_splits,
            values=row_values,
            binary_output=False)
        self.evaluate(sparse_counts)
def testRaggedCountSparseOutputBadSplitsStart(self):
    """A `splits` vector whose first element is not 0 must be rejected."""
    # First split is 1 instead of the required 0.
    row_splits = [1, 7]
    row_values = [1, 1, 2, 1, 2, 10, 5]
    row_weights = [1, 2, 3, 4, 5, 6, 7]
    expected_message = "Splits must start with 0"
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_message):
        # Kept inside the context so an eager-mode failure at construction
        # time is still caught.
        sparse_counts = gen_count_ops.RaggedCountSparseOutput(
            splits=row_splits,
            values=row_values,
            weights=row_weights,
            binary_output=False)
        self.evaluate(sparse_counts)
def testRaggedCountSparseOutput(self):
    """Weighted per-row counts come back as sparse (indices, values, shape).

    With splits [0, 4, 7], row 0 covers values[0:4] and row 1 covers
    values[4:7]. Each expected output value equals the sum of the weights of
    one (row, value) pair. For example, value 1 in row 0 sits at positions
    0, 1 and 3, so its expected count is 1 + 2 + 4 = 7.
    """
    row_splits = [0, 4, 7]
    row_values = [1, 1, 2, 1, 2, 10, 5]
    row_weights = [1, 2, 3, 4, 5, 6, 7]
    sparse_counts = gen_count_ops.RaggedCountSparseOutput(
        splits=row_splits,
        values=row_values,
        weights=row_weights,
        binary_output=False)
    indices, counts, dense_shape = self.evaluate(sparse_counts)

    # (row, value) coordinates of the non-zero counts.
    expected_indices = [[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]]
    # Weight sums for the coordinates above, in the same order.
    expected_counts = [7, 3, 5, 7, 6]
    # 2 rows; 11 columns = max(values) + 1 for this input.
    expected_shape = [2, 11]
    self.assertAllEqual(expected_indices, indices)
    self.assertAllEqual(expected_counts, counts)
    self.assertAllEqual(expected_shape, dense_shape)
def testRaggedCountSparseOutputEmptySplits(self):
    """An empty `splits` vector must be rejected with an argument error."""
    # No split points at all.
    row_splits = []
    row_values = [1, 1, 2, 1, 2, 10, 5]
    row_weights = [1, 2, 3, 4, 5, 6, 7]
    expected_message = (
        "Must provide at least 2 elements for the splits argument")
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_message):
        # Constructed inside the context so an eager-mode error is caught.
        sparse_counts = gen_count_ops.RaggedCountSparseOutput(
            splits=row_splits,
            values=row_values,
            weights=row_weights,
            binary_output=False)
        self.evaluate(sparse_counts)
def testRaggedCountSparseOutputBadWeightsShape(self):
    """`weights` whose length differs from `values` must be rejected."""
    row_splits = [0, 4, 7]
    # Seven values but only six weights.
    row_values = [1, 1, 2, 1, 2, 10, 5]
    row_weights = [1, 2, 3, 4, 5, 6]
    expected_message = "Weights and values must have the same shape"
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_message):
        # Constructed inside the context so an eager-mode error is caught.
        sparse_counts = gen_count_ops.RaggedCountSparseOutput(
            splits=row_splits,
            values=row_values,
            weights=row_weights,
            binary_output=False)
        self.evaluate(sparse_counts)