コード例 #1
0
 def testRaggedCountSparseOutputNegativeValue(self):
   """A negative entry in `values` must be rejected by the op.

   `weights` is passed explicitly, with one entry per element of `values`, as
   in the other RaggedCountSparseOutput tests. The raw op lists `weights` as
   an input. Supplying well-formed splits and weights means the expected
   InvalidArgumentError is attributable to the negative value (-2), not to a
   missing or mismatched `weights` argument.
   """
   splits = [0, 4, 7]
   values = [1, 1, 2, 1, -2, 10, 5]
   weights = [1, 2, 3, 4, 5, 6, 7]
   with self.assertRaisesRegex(errors.InvalidArgumentError,
                               "Input values must all be non-negative"):
     self.evaluate(
         gen_count_ops.RaggedCountSparseOutput(
             splits=splits,
             values=values,
             weights=weights,
             binary_output=False))
コード例 #2
0
 def testRaggedCountSparseOutputBadSplitsStart(self):
   """`splits` whose first boundary is not 0 is rejected."""
   # The first row boundary is 1 rather than 0. The values and weights are
   # both 7 long, and the last split (7) matches that length.
   op_kwargs = {
       "splits": [1, 7],
       "values": [1, 1, 2, 1, 2, 10, 5],
       "weights": [1, 2, 3, 4, 5, 6, 7],
       "binary_output": False,
   }
   with self.assertRaisesRegex(errors.InvalidArgumentError,
                               "Splits must start with 0"):
     self.evaluate(gen_count_ops.RaggedCountSparseOutput(**op_kwargs))
コード例 #3
0
 def testRaggedCountSparseOutput(self):
   """Weighted per-row counts come back as (indices, values, shape).

   Row 0 holds values [1, 1, 2, 1] with weights [1, 2, 3, 4]. Row 1 holds
   values [2, 10, 5] with weights [5, 6, 7]. With binary_output=False, each
   output entry is the summed weight of one distinct value within one row.
   For example, value 1 in row 0 gives 1 + 2 + 4 = 7. The expected shape
   [2, 11] is 2 rows by (largest value 10) + 1 columns.
   """
   row_splits = [0, 4, 7]
   flat_values = [1, 1, 2, 1, 2, 10, 5]
   flat_weights = [1, 2, 3, 4, 5, 6, 7]
   sparse_result = gen_count_ops.RaggedCountSparseOutput(
       splits=row_splits,
       values=flat_values,
       weights=flat_weights,
       binary_output=False)
   indices, counts, dense_shape = self.evaluate(sparse_result)

   expected_indices = [[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]]
   expected_counts = [7, 3, 5, 7, 6]
   expected_shape = [2, 11]
   self.assertAllEqual(expected_indices, indices)
   self.assertAllEqual(expected_counts, counts)
   self.assertAllEqual(expected_shape, dense_shape)
コード例 #4
0
 def testRaggedCountSparseOutputEmptySplits(self):
   """An empty `splits` list is rejected with InvalidArgumentError."""
   expected_message = (
       "Must provide at least 2 elements for the splits argument")
   ragged_kwargs = {
       "splits": [],
       "values": [1, 1, 2, 1, 2, 10, 5],
       "weights": [1, 2, 3, 4, 5, 6, 7],
       "binary_output": False,
   }
   with self.assertRaisesRegex(errors.InvalidArgumentError, expected_message):
     self.evaluate(gen_count_ops.RaggedCountSparseOutput(**ragged_kwargs))
コード例 #5
0
 def testRaggedCountSparseOutputBadWeightsShape(self):
   """`weights` whose shape differs from `values` is rejected."""
   # `weights` is one element shorter than `values` (6 vs. 7).
   flat_values = [1, 1, 2, 1, 2, 10, 5]
   short_weights = [1, 2, 3, 4, 5, 6]
   with self.assertRaisesRegex(
       errors.InvalidArgumentError,
       "Weights and values must have the same shape"):
     self.evaluate(
         gen_count_ops.RaggedCountSparseOutput(
             splits=[0, 4, 7],
             values=flat_values,
             weights=short_weights,
             binary_output=False))