def test2DSpanOverlaps(self, name, source_start, source_limit, target_start,
                       target_limit, expected_overlap_pairs, contains=False,
                       contained_by=False, partial_overlap=False,
                       ragged_rank=None):
  """Checks `pointer_ops.span_overlaps` on batched (2-D) span inputs.

  Args:
    name: Test-case label. It is not used in the body (presumably consumed by
      the parameterized-test decorator).
    source_start: Nested sequence indexed as [batch][span]; converted to a
      ragged tensor before the call.
    source_limit: Same structure as `source_start`.
    target_start: Nested sequence indexed as [batch][span]; converted to a
      ragged tensor before the call.
    target_limit: Same structure as `target_start`.
    expected_overlap_pairs: Collection of `(batch, source_index,
      target_index)` triples whose result entry should be True. Every other
      entry is expected to be False.
    contains: Flag forwarded unchanged to `pointer_ops.span_overlaps`.
    contained_by: Flag forwarded unchanged to `pointer_ops.span_overlaps`.
    partial_overlap: Flag forwarded unchanged to `pointer_ops.span_overlaps`.
    ragged_rank: Optional `ragged_rank` used when building the ragged inputs.
  """
  # Assemble the expected value as a [batch][source][target] nested list of
  # bools. This must happen before the inputs are rebound to tensors below,
  # because it uses Python len() on the raw sequences.
  expected = [[[(b, s, t) in expected_overlap_pairs
                for t in range(len(target_limit[b]))]
               for s in range(len(source_limit[b]))]
              for b in range(self.BATCH_SIZE)]

  # Build ragged inputs with ragged_factory_ops.constant, as the rest of this
  # test class does, instead of depending on a top-level `tf` import.
  source_start = ragged_factory_ops.constant(
      source_start, ragged_rank=ragged_rank)
  source_limit = ragged_factory_ops.constant(
      source_limit, ragged_rank=ragged_rank)
  target_start = ragged_factory_ops.constant(
      target_start, ragged_rank=ragged_rank)
  target_limit = ragged_factory_ops.constant(
      target_limit, ragged_rank=ragged_rank)

  overlaps = pointer_ops.span_overlaps(source_start, source_limit,
                                       target_start, target_limit, contains,
                                       contained_by, partial_overlap)
  # assertAllEqual compares ragged values as well as dense ones.
  self.assertAllEqual(overlaps, expected)
def test3DSpanOverlaps(self, name, source_start, source_limit, target_start,
                       target_limit, expected_overlap_pairs, contains=False,
                       contained_by=False, partial_overlap=False,
                       ragged_rank=None):
  """Checks `pointer_ops.span_overlaps` on doubly-batched (3-D) span inputs.

  Args:
    name: Test-case label. It is not used in the body (presumably consumed by
      the parameterized-test decorator).
    source_start: Nested sequence indexed as [batch1][batch2][span];
      converted to a ragged tensor before the call.
    source_limit: Same structure as `source_start`.
    target_start: Nested sequence indexed as [batch1][batch2][span];
      converted to a ragged tensor before the call.
    target_limit: Same structure as `target_start`.
    expected_overlap_pairs: Collection of `(batch1, batch2, source_index,
      target_index)` tuples whose result entry should be True. Every other
      entry is expected to be False.
    contains: Flag forwarded unchanged to `pointer_ops.span_overlaps`.
    contained_by: Flag forwarded unchanged to `pointer_ops.span_overlaps`.
    partial_overlap: Flag forwarded unchanged to `pointer_ops.span_overlaps`.
    ragged_rank: Optional `ragged_rank` used when building the ragged inputs.
  """
  # Assemble the expected value as a [batch1][batch2][source][target] nested
  # list of bools. Every dimension, including the outermost (previously
  # hard-coded to 2), is derived from the raw input sequences. This must run
  # before those names are rebound to tensors below.
  # pylint: disable=g-complex-comprehension
  expected = [[[[(b1, b2, s, t) in expected_overlap_pairs
                 for t in range(len(target_limit[b1][b2]))]
                for s in range(len(source_limit[b1][b2]))]
               for b2 in range(len(source_limit[b1]))]
              for b1 in range(len(source_limit))]

  source_start = ragged_factory_ops.constant(
      source_start, ragged_rank=ragged_rank)
  source_limit = ragged_factory_ops.constant(
      source_limit, ragged_rank=ragged_rank)
  target_start = ragged_factory_ops.constant(
      target_start, ragged_rank=ragged_rank)
  target_limit = ragged_factory_ops.constant(
      target_limit, ragged_rank=ragged_rank)

  overlaps = pointer_ops.span_overlaps(source_start, source_limit,
                                       target_start, target_limit, contains,
                                       contained_by, partial_overlap)
  self.assertAllEqual(overlaps, expected)
def test1DSpanOverlaps(self, name, source_start, source_limit, target_start,
                       target_limit, expected_overlap_pairs, contains=False,
                       contained_by=False, partial_overlap=False):
  """Checks `pointer_ops.span_overlaps` on flat (1-D) span inputs.

  Args:
    name: Test-case label. It is not used in the body (presumably consumed by
      the parameterized-test decorator).
    source_start: Flat sequence of span offsets, passed directly to
      `pointer_ops.span_overlaps`.
    source_limit: Flat sequence of span offsets, passed directly to
      `pointer_ops.span_overlaps`.
    target_start: Flat sequence of span offsets, passed directly to
      `pointer_ops.span_overlaps`.
    target_limit: Flat sequence of span offsets, passed directly to
      `pointer_ops.span_overlaps`.
    expected_overlap_pairs: Collection of `(source_index, target_index)`
      pairs whose result entry should be True. Every other entry is expected
      to be False.
    contains: Flag forwarded unchanged to `pointer_ops.span_overlaps`.
    contained_by: Flag forwarded unchanged to `pointer_ops.span_overlaps`.
    partial_overlap: Flag forwarded unchanged to `pointer_ops.span_overlaps`.
  """
  # Assemble the expected value. (Writing out the complete expected result
  # matrix takes up a lot of space, so instead we just list the positions
  # in the matrix that should be True.)
  expected = [[(s, t) in expected_overlap_pairs
               for t in range(len(target_limit))]
              for s in range(len(source_limit))]

  overlaps = pointer_ops.span_overlaps(source_start, source_limit,
                                       target_start, target_limit, contains,
                                       contained_by, partial_overlap)
  # assertAllEqual handles both dense and ragged values, matching the
  # assertion used for the 3-D case.
  self.assertAllEqual(overlaps, expected)
def testErrors(self):
  """Checks that `pointer_ops.span_overlaps` rejects invalid arguments.

  Each `assertRaisesRegex` context wraps only the call under test. Inputs are
  built outside the context, so a failure while building them cannot be
  mistaken for the expected error.
  """
  t = [10, 20, 30, 40, 50]

  # Flag arguments that are not bools.
  with self.assertRaisesRegex(TypeError, 'contains must be bool.'):
    pointer_ops.span_overlaps(t, t, t, t, contains='x')
  with self.assertRaisesRegex(TypeError, 'contained_by must be bool.'):
    pointer_ops.span_overlaps(t, t, t, t, contained_by='x')
  with self.assertRaisesRegex(TypeError, 'partial_overlap must be bool.'):
    pointer_ops.span_overlaps(t, t, t, t, partial_overlap='x')

  # Span inputs whose dtypes disagree (int lists vs. a float list).
  with self.assertRaisesRegex(
      TypeError, 'source_start, source_limit, target_start, and '
      'target_limit must all have the same dtype'):
    pointer_ops.span_overlaps(t, t, t, [1.0, 2.0, 3.0, 4.0, 5.0])

  # Start/limit pairs whose lengths differ.
  with self.assertRaisesRegex(ValueError,
                              r'Shapes \(5,\) and \(4,\) are incompatible'):
    pointer_ops.span_overlaps(t, t[:4], t, t)
  with self.assertRaisesRegex(ValueError,
                              r'Shapes \(4,\) and \(5,\) are incompatible'):
    pointer_ops.span_overlaps(t, t, t[:4], t)

  # Source and target spans of different rank.
  with self.assertRaisesRegex(
      ValueError, r'Shapes \(1, 5\) and \(5,\) must have the same rank'):
    pointer_ops.span_overlaps([t], [t], t, t)

  # array_ops.placeholder is only meaningful in graph mode. The placeholder
  # is created without a shape, so (presumably) the ragged tensor's
  # shape.ndims is not statically known, which is what this check expects.
  if not context.executing_eagerly():
    x = ragged_tensor.RaggedTensor.from_row_splits(
        array_ops.placeholder(dtypes.int32), [0, 3, 8])
    with self.assertRaisesRegex(
        ValueError, 'For ragged inputs, the shape.ndims of at least one '
        'span tensor must be statically known.'):
      pointer_ops.span_overlaps(x, x, x, x)

  # Mixing plain nested lists with a ragged tensor.
  a = [[10, 20, 30], [40, 50, 60]]
  ragged_a = ragged_factory_ops.constant(a)
  with self.assertRaisesRegex(
      ValueError, 'Span tensors must all have the same ragged_rank'):
    pointer_ops.span_overlaps(a, a, a, ragged_a)

  # Ragged inputs whose batch dimensions have different row lengths.
  rt1 = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5]]])
  rt2 = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5], [6]]])
  with self.assertRaisesRegex(
      errors.InvalidArgumentError,
      'Mismatched ragged shapes for batch dimensions'):
    pointer_ops.span_overlaps(rt1, rt1, rt2, rt2)