def testRowPartitionConstructionErrors(self):
    """Checks that invalid RowPartition constructor arguments are rejected."""
    factory_key = row_partition._row_partition_factory_key
    splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

    # Without the private factory key the constructor is expected to refuse.
    with self.assertRaisesRegex(ValueError,
                                'RowPartition constructor is private'):
        RowPartition(row_splits=splits)

    # A plain Python list is not accepted as a partitioning argument.
    with self.assertRaisesRegex(
            TypeError, 'Row-partitioning argument must be a Tensor'):
        RowPartition(row_splits=[0, 2, 2, 5, 6, 7], internal=factory_key)

    # row_splits with shape (6, 1) is rejected for not being rank 1.
    with self.assertRaisesRegex(ValueError,
                                r'Shape \(6, 1\) must have rank 1'):
        RowPartition(row_splits=array_ops.expand_dims(splits, 1),
                     internal=factory_key)

    # A cached value given as a Python list (not a Tensor) is rejected.
    with self.assertRaisesRegex(TypeError,
                                'Cached value must be a Tensor or None.'):
        RowPartition(row_splits=splits, row_lengths=[2, 3, 4],
                     internal=factory_key)

    # int64 row_splits combined with an int32 nrows is rejected.
    with self.assertRaisesRegex(ValueError, 'Inconsistent dtype'):
        RowPartition(
            row_splits=constant_op.constant([0, 3], dtypes.int64),
            nrows=constant_op.constant(1, dtypes.int32),
            internal=factory_key)
def _get_specified_row_partition():
    """Builds a RowPartition with row_splits, nrows and nvals all given.

    Needed for merge_with_spec tests. Normally, nvals isn't set.

    Returns:
      A RowPartition built from int64 row_splits [0, 3, 8], nrows 2 and
      nvals 8.
    """
    int64 = dtypes.int64
    splits = constant_op.constant([0, 3, 8], dtype=int64)
    num_rows = constant_op.constant(2, dtype=int64)
    num_values = constant_op.constant(8, dtype=int64)
    return RowPartition(
        row_splits=splits,
        nrows=num_rows,
        nvals=num_values,
        internal=row_partition._row_partition_factory_key)
def testRaggedTensorConstructionErrors(self):
    """Checks that invalid RowPartition constructor arguments are rejected.

    The private constructor is unlocked with
    row_partition._row_partition_factory_key. Any other `internal` value
    (e.g. True) is refused before argument validation runs, so the TypeError
    and rank checks below must pass the factory key to be reachable.
    """
    factory_key = row_partition._row_partition_factory_key
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias.
    with self.assertRaisesRegex(ValueError,
                                'RowPartition constructor is private'):
        RowPartition(row_splits=row_splits)

    # A plain Python list is not accepted as a partitioning argument.
    with self.assertRaisesRegex(
            TypeError, 'Row-partitioning argument must be a Tensor'):
        RowPartition(row_splits=[0, 2, 2, 5, 6, 7], internal=factory_key)

    # row_splits with shape (6, 1) is rejected for not being rank 1.
    with self.assertRaisesRegex(ValueError,
                                r'Shape \(6, 1\) must have rank 1'):
        RowPartition(row_splits=array_ops.expand_dims(row_splits, 1),
                     internal=factory_key)

    # Cached row lengths are passed via `row_lengths` and must be a Tensor.
    with self.assertRaisesRegex(TypeError,
                                'Cached value must be a Tensor or None.'):
        RowPartition(row_splits=row_splits, row_lengths=[2, 3, 4],
                     internal=factory_key)
def testRaggedTensorConstruction(self):
    """Builds a RowPartition from int64 row_splits and reads them back."""
    expected_splits = [0, 2, 2, 5, 6, 7]
    partition = RowPartition(
        row_splits=constant_op.constant(expected_splits, dtypes.int64),
        internal=row_partition._row_partition_factory_key)
    self.assertAllEqual(partition.row_splits(), expected_splits)
def testRaggedTensorConstruction(self):
    """Builds a RowPartition from int64 row_splits and reads them back.

    NOTE(review): an earlier method in this file has the same name; this
    later definition replaces it, so only one of the two runs. Consider
    renaming one of them.
    """
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    # The constructor is private: unlock it with the factory key, not True.
    rp = RowPartition(
        row_splits=row_splits,
        internal=row_partition._row_partition_factory_key)
    # row_splits is an accessor method; compare its result, not the method.
    self.assertAllEqual(rp.row_splits(), [0, 2, 2, 5, 6, 7])