Пример #1
0
    def testRowPartitionConstructionErrors(self):
        """Checks that malformed RowPartition constructor calls are rejected.

        Cases exercised, in order:
          * calling the constructor without the `internal` factory key;
          * passing `row_splits` as a Python list instead of a Tensor;
          * passing rank-2 `row_splits`;
          * passing `row_lengths` as a Python list instead of a Tensor;
          * pairing int64 `row_splits` with an int32 `nrows`.
        """
        factory_key = row_partition._row_partition_factory_key
        splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

        # Direct construction without the private key is not allowed.
        with self.assertRaisesRegex(ValueError,
                                    'RowPartition constructor is private'):
            RowPartition(row_splits=splits)

        # row_splits must already be a Tensor, not a plain list.
        with self.assertRaisesRegex(
                TypeError, 'Row-partitioning argument must be a Tensor'):
            RowPartition(internal=factory_key,
                         row_splits=[0, 2, 2, 5, 6, 7])

        # row_splits must be rank 1; expand_dims makes it shape (6, 1).
        with self.assertRaisesRegex(ValueError,
                                    r'Shape \(6, 1\) must have rank 1'):
            column_splits = array_ops.expand_dims(splits, 1)
            RowPartition(internal=factory_key, row_splits=column_splits)

        # Cached encodings such as row_lengths must also be Tensors.
        with self.assertRaisesRegex(TypeError,
                                    'Cached value must be a Tensor or None.'):
            RowPartition(internal=factory_key,
                         row_splits=splits,
                         row_lengths=[2, 3, 4])

        # Mixing int64 row_splits with an int32 nrows is a dtype mismatch.
        with self.assertRaisesRegex(ValueError, 'Inconsistent dtype'):
            int64_splits = constant_op.constant([0, 3], dtypes.int64)
            int32_nrows = constant_op.constant(1, dtypes.int32)
            RowPartition(internal=factory_key,
                         row_splits=int64_splits,
                         nrows=int32_nrows)
Пример #2
0
def _get_specified_row_partition():
  """Returns a RowPartition whose row_splits, nrows and nvals are all given.

  Intended for the merge_with_spec tests: ordinary construction usually
  leaves nvals unset, so this helper supplies it explicitly. All three
  tensors are int64: row_splits=[0, 3, 8], nrows=2, nvals=8.
  """
  int64 = dtypes.int64
  splits = constant_op.constant([0, 3, 8], dtype=int64)
  num_rows = constant_op.constant(2, dtype=int64)
  num_values = constant_op.constant(8, dtype=int64)
  return RowPartition(
      row_splits=splits,
      nrows=num_rows,
      nvals=num_values,
      internal=row_partition._row_partition_factory_key)
Пример #3
0
    def testRaggedTensorConstructionErrors(self):
        """Checks that malformed RowPartition constructor calls are rejected.

        Uses `assertRaisesRegex`. The old `assertRaisesRegexp` alias was
        deprecated in Python 3.2 and removed from unittest in Python 3.12.

        NOTE(review): this variant targets an API where the private
        constructor is unlocked with `internal=True` and the cached row
        lengths keyword is `cached_row_lengths`. Newer code passes
        `row_partition._row_partition_factory_key` instead. The expected
        'RaggedTensor constructor is private' message presumably reflects
        the wording RowPartition used in that version; confirm it against
        the targeted TensorFlow release.
        """
        row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

        # Direct construction without `internal` must be rejected.
        with self.assertRaisesRegex(ValueError,
                                    'RaggedTensor constructor is private'):
            RowPartition(row_splits=row_splits)

        # row_splits given as a Python list instead of a Tensor.
        with self.assertRaisesRegex(
                TypeError, 'Row-partitioning argument must be a Tensor'):
            RowPartition(row_splits=[0, 2, 2, 5, 6, 7], internal=True)

        # row_splits must be rank 1; expand_dims makes it shape (6, 1).
        with self.assertRaisesRegex(ValueError,
                                    r'Shape \(6, 1\) must have rank 1'):
            RowPartition(row_splits=array_ops.expand_dims(row_splits, 1),
                         internal=True)

        # Cached row lengths given as a Python list instead of a Tensor.
        with self.assertRaisesRegex(TypeError,
                                    'Cached value must be a Tensor or None.'):
            RowPartition(row_splits=row_splits,
                         cached_row_lengths=[2, 3, 4],
                         internal=True)
Пример #4
0
 def testRaggedTensorConstruction(self):
   """Checks that row_splits() returns the splits the partition was built from.

   NOTE(review): despite its name, this test constructs a RowPartition,
   not a RaggedTensor.
   """
   expected_splits = [0, 2, 2, 5, 6, 7]
   splits_tensor = constant_op.constant(expected_splits, dtypes.int64)
   partition = RowPartition(
       internal=row_partition._row_partition_factory_key,
       row_splits=splits_tensor)
   self.assertAllEqual(partition.row_splits(), expected_splits)
Пример #5
0
 def testRaggedTensorConstruction(self):
     """Checks that a RowPartition exposes the row_splits it was built from.

     This variant uses the legacy `internal=True` flag and reads
     `row_splits` as an attribute rather than calling it.
     """
     expected_splits = [0, 2, 2, 5, 6, 7]
     splits_tensor = constant_op.constant(expected_splits, dtypes.int64)
     partition = RowPartition(internal=True, row_splits=splits_tensor)
     self.assertAllEqual(partition.row_splits, expected_splits)