def testUnknownRankError(self):
    """Check that stacking rejects `partitions` built from a rank-less placeholder.

    The check only runs outside eager execution; in eager mode the test
    body is skipped entirely (presumably because placeholders are a
    graph-mode construct -- confirm against the TF docs).
    """
    if not context.executing_eagerly():
        # Second argument is `None`, i.e. no shape information is supplied.
        unranked_partitions = array_ops.placeholder(dtypes.int32, None)
        expected_errors = (ValueError, errors.InvalidArgumentError)
        with self.assertRaisesRegex(expected_errors,
                                    'partitions must have known rank'):
            ragged_array_ops.stack_dynamic_partitions(
                ['a', 'b', 'c'], unranked_partitions, 10)
def _ragged_dynamic_partition(data, partitions, num_partitions, name=None):
    """RaggedTensor dispatch override for tf.dynamic_partition.

    Delegates to `ragged_array_ops.stack_dynamic_partitions` and then
    unstacks the result into a Python list with one entry per partition.

    Args:
      data: Forwarded unchanged to `stack_dynamic_partitions`.
      partitions: Forwarded unchanged to `stack_dynamic_partitions`.
      num_partitions: A non-negative Python `int`; the length of the
        returned list.
      name: Optional name, forwarded to `stack_dynamic_partitions`.

    Returns:
      A list of length `num_partitions` whose i-th element is row `i` of
      the stacked result.

    Raises:
      TypeError: If `num_partitions` is not an `int`, or is negative.
    """
    invalid_count_message = 'num_partitions must be a non-negative integer'
    if not isinstance(num_partitions, int):
        raise TypeError(invalid_count_message)
    if num_partitions < 0:
        raise TypeError(invalid_count_message)

    stacked = ragged_array_ops.stack_dynamic_partitions(
        data, partitions, num_partitions, name)

    # Index row by row so the caller receives a plain list of partitions.
    pieces = []
    for row in range(num_partitions):
        pieces.append(stacked[row])
    return pieces
def testRaggedSegmentStack(self,
                           data,
                           partitions,
                           num_partitions,
                           expected,
                           data_ragged_rank=None,
                           segment_ids_ragged_rank=None,
                           expected_ragged_rank=None):
    """Compare `stack_dynamic_partitions` output against `expected`.

    The comparison is repeated with int32 and int64 used as the
    row-splits dtype of every constant and as the dtype of the partition
    ids. When both ragged ranks are given as 0 and the dtype is int32,
    the result is additionally compared with stacking the output of
    `data_flow_ops.dynamic_partition`.

    Args:
      data: Python value passed to `ragged_factory_ops.constant`.
      partitions: Python value holding the partition ids.
      num_partitions: Number of partitions requested.
      expected: Python value the result must equal.
      data_ragged_rank: Optional `ragged_rank` used when building `data`.
      segment_ids_ragged_rank: Optional `ragged_rank` used when building
        the partition ids.
      expected_ragged_rank: Optional `ragged_rank` used when building the
        expected value.
    """
    # This part of the cross-check condition does not depend on the dtype.
    ranks_allow_cross_check = (
        data_ragged_rank == 0 and segment_ids_ragged_rank == 0)

    for index_dtype in (dtypes.int32, dtypes.int64):
        data_rt = ragged_factory_ops.constant(
            data, row_splits_dtype=index_dtype, ragged_rank=data_ragged_rank)
        partitions_rt = ragged_factory_ops.constant(
            partitions,
            dtype=index_dtype,
            row_splits_dtype=index_dtype,
            ragged_rank=segment_ids_ragged_rank)
        expected_rt = ragged_factory_ops.constant(
            expected,
            row_splits_dtype=index_dtype,
            ragged_rank=expected_ragged_rank)

        stacked = ragged_array_ops.stack_dynamic_partitions(
            data_rt, partitions_rt, num_partitions)
        self.assertAllEqual(stacked, expected_rt)

        # Where applicable, the op must agree with
        # tf.stack(dynamic_partition(...)).
        if ranks_allow_cross_check and index_dtype == dtypes.int32:
            partition_list = data_flow_ops.dynamic_partition(
                data_rt, partitions_rt, num_partitions)
            reference = ragged_concat_ops.stack(partition_list)
            reference_value = self.evaluate(reference)
            self.assertAllEqual(stacked, reference_value.to_list())
def testRuntimeError(self, data, partitions, num_partitions, error):
    """Check that building and evaluating the op fails with `error`.

    Both inputs are first converted with `ragged_factory_ops.constant`
    (the partition ids as int64). Either a `ValueError` or an
    `errors.InvalidArgumentError` whose message matches the regex
    `error` satisfies the test.

    Args:
      data: Python value converted to the data input.
      partitions: Python value converted to the int64 partition ids.
      num_partitions: Number of partitions requested.
      error: Regex the raised error message must match.
    """
    data_rt = ragged_factory_ops.constant(data)
    partitions_rt = ragged_factory_ops.constant(partitions, dtype=dtypes.int64)
    expected_errors = (ValueError, errors.InvalidArgumentError)
    with self.assertRaisesRegex(expected_errors, error):
        # Construction and evaluation both sit inside the context, so a
        # failure at either step counts.
        stacked = ragged_array_ops.stack_dynamic_partitions(
            data_rt, partitions_rt, num_partitions)
        self.evaluate(stacked)
def testStaticError(self, data, partitions, num_partitions, error):
    """Check that the `stack_dynamic_partitions` call itself fails with `error`.

    The inputs are passed through as given and the result is never
    evaluated, so the error has to come from the call. Either a
    `ValueError` or an `errors.InvalidArgumentError` whose message
    matches the regex `error` satisfies the test.

    Args:
      data: Value passed as the data input.
      partitions: Value passed as the partition ids.
      num_partitions: Number of partitions requested.
      error: Regex the raised error message must match.
    """
    expected_errors = (ValueError, errors.InvalidArgumentError)
    with self.assertRaisesRegex(expected_errors, error):
        ragged_array_ops.stack_dynamic_partitions(
            data, partitions, num_partitions)