Ejemplo n.º 1
0
  def test2DSpanOverlaps(self,
                         name,
                         source_start,
                         source_limit,
                         target_start,
                         target_limit,
                         expected_overlap_pairs,
                         contains=False,
                         contained_by=False,
                         partial_overlap=False,
                         ragged_rank=None):
    """Checks `pointer_ops.span_overlaps` on batched (rank-2) span inputs.

    Args:
      name: Test-case label; not read by the body (presumably consumed by a
        parameterization decorator -- TODO confirm).
      source_start: Nested list of source span starts, indexed `[b][s]`.
      source_limit: Nested list of source span limits, indexed `[b][s]`.
      target_start: Nested list of target span starts, indexed `[b][t]`.
      target_limit: Nested list of target span limits, indexed `[b][t]`.
      expected_overlap_pairs: Collection of `(b, s, t)` triples whose result
        should be True; every other position is expected to be False.
      contains: Forwarded to `span_overlaps`.
      contained_by: Forwarded to `span_overlaps`.
      partial_overlap: Forwarded to `span_overlaps`.
      ragged_rank: Forwarded to `ragged_factory_ops.constant` for each input.
    """
    # Assemble expected value.  The batch size is derived from the inputs
    # (rather than a class constant) so the expected matrix always has the
    # same batch dimension as the spans under test.
    expected = [[[(b, s, t) in expected_overlap_pairs
                  for t in range(len(target_limit[b]))]
                 for s in range(len(source_limit[b]))]
                for b in range(len(source_limit))]

    # ragged_factory_ops.constant is the function exported publicly as
    # tf.ragged.constant; using it directly matches the rest of this file.
    source_start = ragged_factory_ops.constant(
        source_start, ragged_rank=ragged_rank)
    source_limit = ragged_factory_ops.constant(
        source_limit, ragged_rank=ragged_rank)
    target_start = ragged_factory_ops.constant(
        target_start, ragged_rank=ragged_rank)
    target_limit = ragged_factory_ops.constant(
        target_limit, ragged_rank=ragged_rank)
    overlaps = pointer_ops.span_overlaps(source_start, source_limit,
                                         target_start, target_limit, contains,
                                         contained_by, partial_overlap)
    self.assertRaggedEqual(overlaps, expected)
Ejemplo n.º 2
0
  def test3DSpanOverlaps(self,
                         name,
                         source_start,
                         source_limit,
                         target_start,
                         target_limit,
                         expected_overlap_pairs,
                         contains=False,
                         contained_by=False,
                         partial_overlap=False,
                         ragged_rank=None):
    """Checks `pointer_ops.span_overlaps` on doubly-batched (rank-3) inputs.

    Args:
      name: Test-case label; not read by the body (presumably consumed by a
        parameterization decorator -- TODO confirm).
      source_start: Nested list of source span starts, indexed `[b1][b2][s]`.
      source_limit: Nested list of source span limits, indexed `[b1][b2][s]`.
      target_start: Nested list of target span starts, indexed `[b1][b2][t]`.
      target_limit: Nested list of target span limits, indexed `[b1][b2][t]`.
      expected_overlap_pairs: Collection of `(b1, b2, s, t)` tuples whose
        result should be True; every other position is expected to be False.
      contains: Forwarded to `span_overlaps`.
      contained_by: Forwarded to `span_overlaps`.
      partial_overlap: Forwarded to `span_overlaps`.
      ragged_rank: Forwarded to `ragged_factory_ops.constant` for each input.
    """
    # Assemble expected value.  Every dimension, including the outermost
    # batch dimension (previously hard-coded to 2), is derived from the
    # inputs so that cases with any outer batch size are checked correctly.
    # pylint: disable=g-complex-comprehension
    expected = [[[[(b1, b2, s, t) in expected_overlap_pairs
                   for t in range(len(target_limit[b1][b2]))]
                  for s in range(len(source_limit[b1][b2]))]
                 for b2 in range(len(source_limit[b1]))]
                for b1 in range(len(source_limit))]

    source_start = ragged_factory_ops.constant(
        source_start, ragged_rank=ragged_rank)
    source_limit = ragged_factory_ops.constant(
        source_limit, ragged_rank=ragged_rank)
    target_start = ragged_factory_ops.constant(
        target_start, ragged_rank=ragged_rank)
    target_limit = ragged_factory_ops.constant(
        target_limit, ragged_rank=ragged_rank)
    overlaps = pointer_ops.span_overlaps(source_start, source_limit,
                                         target_start, target_limit, contains,
                                         contained_by, partial_overlap)
    self.assertAllEqual(overlaps, expected)
Ejemplo n.º 3
0
  def test1DSpanOverlaps(self,
                         name,
                         source_start,
                         source_limit,
                         target_start,
                         target_limit,
                         expected_overlap_pairs,
                         contains=False,
                         contained_by=False,
                         partial_overlap=False):
    """Checks `pointer_ops.span_overlaps` on unbatched (rank-1) span inputs.

    Rather than spelling out the full boolean result matrix, each test case
    lists only the `(s, t)` positions that should be True; the matrix is
    rebuilt here with False everywhere else.

    Args:
      name: Test-case label; not read by the body (presumably consumed by a
        parameterization decorator -- TODO confirm).
      source_start: Source span start offsets.
      source_limit: Source span limit offsets.
      target_start: Target span start offsets.
      target_limit: Target span limit offsets.
      expected_overlap_pairs: Collection of `(s, t)` pairs that should be
        True in the result.
      contains: Forwarded to `span_overlaps`.
      contained_by: Forwarded to `span_overlaps`.
      partial_overlap: Forwarded to `span_overlaps`.
    """
    num_sources = len(source_limit)
    num_targets = len(target_limit)
    expected = [
        [(s, t) in expected_overlap_pairs for t in range(num_targets)]
        for s in range(num_sources)
    ]

    overlaps = pointer_ops.span_overlaps(
        source_start, source_limit, target_start, target_limit,
        contains=contains,
        contained_by=contained_by,
        partial_overlap=partial_overlap)
    self.assertRaggedEqual(overlaps, expected)
Ejemplo n.º 4
0
  def testErrors(self):
    """Checks that `span_overlaps` rejects invalid arguments.

    Each case asserts the exception type and a regex that the error message
    must match.  Fixtures are built outside the `with` blocks, so only the
    call under test can satisfy the expected exception.
    """
    t = [10, 20, 30, 40, 50]

    # Non-bool flag arguments.
    with self.assertRaisesRegex(TypeError, 'contains must be bool.'):
      pointer_ops.span_overlaps(t, t, t, t, contains='x')
    with self.assertRaisesRegex(TypeError, 'contained_by must be bool.'):
      pointer_ops.span_overlaps(t, t, t, t, contained_by='x')
    with self.assertRaisesRegex(TypeError, 'partial_overlap must be bool.'):
      pointer_ops.span_overlaps(t, t, t, t, partial_overlap='x')

    # Mixed dtypes (int vs. float) among the four span tensors.
    with self.assertRaisesRegex(
        TypeError, 'source_start, source_limit, target_start, and '
        'target_limit must all have the same dtype'):
      pointer_ops.span_overlaps(t, t, t, [1.0, 2.0, 3.0, 4.0, 5.0])

    # Start/limit length mismatches and rank mismatches.
    with self.assertRaisesRegex(ValueError,
                                r'Shapes \(5,\) and \(4,\) are incompatible'):
      pointer_ops.span_overlaps(t, t[:4], t, t)
    with self.assertRaisesRegex(ValueError,
                                r'Shapes \(4,\) and \(5,\) are incompatible'):
      pointer_ops.span_overlaps(t, t, t[:4], t)
    with self.assertRaisesRegex(
        ValueError, r'Shapes \(1, 5\) and \(5,\) must have the same rank'):
      pointer_ops.span_overlaps([t], [t], t, t)

    # A ragged input whose rank is statically unknown can only be built from
    # a placeholder, so this case runs only in graph mode.
    if not context.executing_eagerly():
      x = ragged_tensor.RaggedTensor.from_row_splits(
          array_ops.placeholder(dtypes.int32), [0, 3, 8])
      with self.assertRaisesRegex(
          ValueError, 'For ragged inputs, the shape.ndims of at least one '
          'span tensor must be statically known.'):
        pointer_ops.span_overlaps(x, x, x, x)

    # Three dense span tensors mixed with one ragged span tensor.
    a = [[10, 20, 30], [40, 50, 60]]
    ragged_a = ragged_factory_ops.constant(a)
    with self.assertRaisesRegex(
        ValueError, 'Span tensors must all have the same ragged_rank'):
      pointer_ops.span_overlaps(a, a, a, ragged_a)

    # Source and target ragged tensors whose batch dimensions disagree.
    rt1 = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5]]])
    rt2 = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5], [6]]])
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        'Mismatched ragged shapes for batch dimensions'):
      # Evaluate the result so that a check which only fires at run time is
      # still executed in graph mode.
      self.evaluate(pointer_ops.span_overlaps(rt1, rt1, rt2, rt2))