コード例 #1
0
 def testUnknownRankError(self):
     """Checks that partition ids of unknown rank are rejected.

     Only runs when not executing eagerly, because it builds the partition
     ids from a placeholder whose shape is left unspecified (None).
     """
     if not context.executing_eagerly():
         unknown_rank_ids = array_ops.placeholder(dtypes.int32, None)
         accepted_errors = (ValueError, errors.InvalidArgumentError)
         with self.assertRaisesRegex(accepted_errors,
                                     'partitions must have known rank'):
             ragged_array_ops.stack_dynamic_partitions(
                 ['a', 'b', 'c'], unknown_rank_ids, 10)
コード例 #2
0
def _ragged_dynamic_partition(data, partitions, num_partitions, name=None):
    """RaggedTensor Dispatch override for tf.dynamic_partition.

    Args:
      data: The values to partition; forwarded unchanged to
        `ragged_array_ops.stack_dynamic_partitions`.
      partitions: The partition id for each value; forwarded unchanged to
        `ragged_array_ops.stack_dynamic_partitions`.
      num_partitions: The number of partitions to return. Must be a
        non-negative Python int or an integer-like object implementing
        `__index__` (for example `numpy.int64`).
      name: Optional name forwarded to
        `ragged_array_ops.stack_dynamic_partitions`.

    Returns:
      A Python list of length `num_partitions` whose i-th element is
      `result[i]`, where `result` is the output of
      `ragged_array_ops.stack_dynamic_partitions`.

    Raises:
      TypeError: If `num_partitions` is not an integer or is negative.
    """
    import operator  # Local import: keeps this override self-contained.

    try:
        # operator.index accepts int and integer-like types (e.g. numpy ints),
        # matching what the dense op accepts, while still rejecting floats,
        # strings and None. It also returns a plain int for `range` below.
        count = operator.index(num_partitions)
    except TypeError:
        raise TypeError(
            'num_partitions must be a non-negative integer') from None
    if count < 0:
        raise TypeError('num_partitions must be a non-negative integer')
    result = ragged_array_ops.stack_dynamic_partitions(data, partitions,
                                                       count, name)
    # Unstack along the first dimension so the return type matches
    # tf.dynamic_partition (a list with one entry per partition).
    return [result[i] for i in range(count)]
コード例 #3
0
    def testRaggedSegmentStack(self,
                               data,
                               partitions,
                               num_partitions,
                               expected,
                               data_ragged_rank=None,
                               segment_ids_ragged_rank=None,
                               expected_ragged_rank=None):
        """Checks stack_dynamic_partitions against an expected ragged result.

        The check runs twice, with int32 and with int64. That dtype is used as
        the row-splits dtype of every constructed ragged constant and as the
        dtype of the partition ids.

        Args:
          data: Nested Python list passed to ragged_factory_ops.constant.
          partitions: Nested Python list of partition ids, passed to
            ragged_factory_ops.constant.
          num_partitions: Forwarded to the op under test.
          expected: Nested Python list holding the expected stacked output.
          data_ragged_rank: `ragged_rank` used when building `data`.
          segment_ids_ragged_rank: `ragged_rank` used when building
            `partitions`.
          expected_ragged_rank: `ragged_rank` used when building `expected`.
        """
        # The tf.stack(tf.dynamic_partition(...)) cross-check is only attempted
        # when both inputs were built with ragged_rank=0.
        both_rank_zero = (data_ragged_rank == 0 and
                          segment_ids_ragged_rank == 0)
        for ids_dtype in (dtypes.int32, dtypes.int64):
            data_rt = ragged_factory_ops.constant(
                data, ragged_rank=data_ragged_rank, row_splits_dtype=ids_dtype)
            ids_rt = ragged_factory_ops.constant(
                partitions,
                ragged_rank=segment_ids_ragged_rank,
                dtype=ids_dtype,
                row_splits_dtype=ids_dtype)
            expected_rt = ragged_factory_ops.constant(
                expected,
                ragged_rank=expected_ragged_rank,
                row_splits_dtype=ids_dtype)

            actual = ragged_array_ops.stack_dynamic_partitions(
                data_rt, ids_rt, num_partitions)
            self.assertAllEqual(actual, expected_rt)

            # Only cross-check with int32 ids; presumably because
            # tf.dynamic_partition expects int32 partition ids (TODO confirm).
            if not (both_rank_zero and ids_dtype == dtypes.int32):
                continue
            reference = ragged_concat_ops.stack(
                data_flow_ops.dynamic_partition(data_rt, ids_rt,
                                                num_partitions))
            self.assertAllEqual(actual, self.evaluate(reference).to_list())
コード例 #4
0
 def testRuntimeError(self, data, partitions, num_partitions, error):
     """Checks that invalid inputs raise once the op is evaluated.

     Args:
       data: Nested Python list converted with ragged_factory_ops.constant.
       partitions: Nested Python list of partition ids, converted with
         dtype int64.
       num_partitions: Forwarded to stack_dynamic_partitions.
       error: Regex that the raised error message must match.
     """
     data_rt = ragged_factory_ops.constant(data)
     ids_rt = ragged_factory_ops.constant(partitions, dtype=dtypes.int64)
     with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                                 error):
         stacked = ragged_array_ops.stack_dynamic_partitions(
             data_rt, ids_rt, num_partitions)
         self.evaluate(stacked)
コード例 #5
0
 def testStaticError(self, data, partitions, num_partitions, error):
     """Checks that the call to stack_dynamic_partitions itself raises.

     Unlike a runtime check, nothing is passed to self.evaluate here. The
     arguments go straight to the op without conversion.

     Args:
       data: Forwarded unchanged to stack_dynamic_partitions.
       partitions: Forwarded unchanged to stack_dynamic_partitions.
       num_partitions: Forwarded unchanged to stack_dynamic_partitions.
       error: Regex that the raised error message must match.
     """
     accepted_errors = (ValueError, errors.InvalidArgumentError)
     with self.assertRaisesRegex(accepted_errors, error):
         ragged_array_ops.stack_dynamic_partitions(
             data, partitions, num_partitions)