예제 #1
0
 def _set_union(self, a, b):
     """Build the union of `a` and `b` four equivalent ways and return it.

     `sets.set_union` is called with the operands in both orders, and for
     each order with `validate_indices` both True and False. Per the intent
     of this helper, all four variants should give the same result; the
     comparison itself is delegated to `_run_equivalent_set_ops`. Each op's
     static shape is first checked against `a`.

     Returns:
       Whatever `_run_equivalent_set_ops` returns for the four ops.
     """
     # Outer loop over operand order, inner over validation flag, giving
     # (a,b,True), (a,b,False), (b,a,True), (b,a,False) in that order.
     ops = tuple(
         sets.set_union(lhs, rhs, validate_indices=validate)
         for lhs, rhs in ((a, b), (b, a))
         for validate in (True, False)
     )
     for union_op in ops:
         self._assert_static_shapes(a, union_op)
     return self._run_equivalent_set_ops(ops)
예제 #2
0
 def _set_union(self, a, b):
   """Return the union of `a` and `b`, built via four equivalent ops.

   The union is requested with `a`/`b` in both orders and with
   `validate_indices` on and off. The intent is that every variant agrees;
   `_run_equivalent_set_ops` receives all four (as a tuple) and its result is
   returned. Static shapes are checked against `a` beforehand.
   """
   union_ops = []
   for first, second in ((a, b), (b, a)):
     for validate in (True, False):
       union_ops.append(
           sets.set_union(first, second, validate_indices=validate))
   ops = tuple(union_ops)
   for op in ops:
     self._assert_static_shapes(a, op)
   return self._run_equivalent_set_ops(ops)
예제 #3
0
 def _set_union(self, a, b):
     """Compute the union of `a` and `b` and check equivalent variants agree.

     Builds `sets.set_union` four ways: with the operands in both orders and
     with `validate_indices` both True and False. It checks each op's shape
     against `a` via `_assert_shapes` and evaluates all ops in a single
     `sess.run`. It then asserts that every variant's `indices`, `values` and
     `dense_shape` equal those of the first variant.

     Returns:
       The evaluated result of `sets.set_union(a, b, validate_indices=True)`.
     """
     ops = (
         sets.set_union(a, b, validate_indices=True),
         sets.set_union(a, b, validate_indices=False),
         sets.set_union(b, a, validate_indices=True),
         sets.set_union(b, a, validate_indices=False),
     )
     for op in ops:
         self._assert_shapes(a, op)
     with self.test_session() as sess:
         results = sess.run(ops)
     # Compare every other variant against the first rather than hard-coding
     # the count, so editing `ops` above can never silently skip a check.
     expected = results[0]
     for actual in results[1:]:
         self.assertAllEqual(expected.indices, actual.indices)
         self.assertAllEqual(expected.values, actual.values)
         self.assertAllEqual(expected.dense_shape, actual.dense_shape)
     return expected
예제 #4
0
 def _set_union(self, a, b):
   """Compute the union of `a` and `b` and check equivalent variants agree.

   Builds `sets.set_union` four ways: with the operands in both orders and
   with `validate_indices` both True and False. It checks each op's shape
   against `a` via `_assert_shapes` and evaluates all ops in a single
   `sess.run`. It then asserts that every variant's `indices`, `values` and
   `dense_shape` equal those of the first variant.

   Returns:
     The evaluated result of `sets.set_union(a, b, validate_indices=True)`.
   """
   ops = (
       sets.set_union(
           a, b, validate_indices=True),
       sets.set_union(
           a, b, validate_indices=False),
       sets.set_union(
           b, a, validate_indices=True),
       sets.set_union(
           b, a, validate_indices=False),)
   for op in ops:
     self._assert_shapes(a, op)
   with self.test_session() as sess:
     results = sess.run(ops)
   # Compare every other variant against the first rather than hard-coding
   # the count, so editing `ops` above can never silently skip a check.
   expected = results[0]
   for actual in results[1:]:
     self.assertAllEqual(expected.indices, actual.indices)
     self.assertAllEqual(expected.values, actual.values)
     self.assertAllEqual(expected.dense_shape, actual.dense_shape)
   return expected
예제 #5
0
    def test_set_union_output_is_sorted(self, dtype):
        # Only single-digit values are used, so lexicographic and numeric
        # ordering coincide and the expectations below also hold when
        # dtype == tf.string.

        # Set A, one row per line (row 2 is empty):
        #   [3 7 5 3 1]
        #   [2 6 5 4]
        #   []
        #   [9 8]
        a_indices = constant_op.constant(
            [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [1, 0], [1, 1],
             [1, 2], [1, 3], [3, 0], [3, 1]],
            dtype=dtypes.int64)
        a_values = _constant([3, 7, 5, 3, 1, 2, 6, 5, 4, 9, 8], dtype)
        a_shape = constant_op.constant([4, 5], dtype=dtypes.int64)
        sp_a = sparse_tensor_lib.SparseTensor(
            indices=a_indices, values=a_values, dense_shape=a_shape)

        # Set B, one row per line (row 3 is empty):
        #   [9 7]
        #   [5 2 0]
        #   [6]
        #   []
        b_indices = constant_op.constant(
            [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [2, 0]],
            dtype=dtypes.int64)
        b_values = _constant([9, 7, 5, 2, 0, 6], dtype)
        b_shape = constant_op.constant([4, 3], dtype=dtypes.int64)
        sp_b = sparse_tensor_lib.SparseTensor(
            indices=b_indices, values=b_values, dense_shape=b_shape)

        # Expected union, each row sorted ascending:
        #   [1 3 5 7 9]
        #   [0 2 4 5 6]
        #   [6]
        #   [8 9]
        result = sets.set_union(sp_a, sp_b)
        self.assertAllEqual(result.dense_shape, [4, 5])
        expected_indices = [
            [0, 0], [0, 1], [0, 2], [0, 3], [0, 4],  # row 0
            [1, 0], [1, 1], [1, 2], [1, 3], [1, 4],  # row 1
            [2, 0],                                  # row 2
            [3, 0], [3, 1],                          # row 3
        ]
        self.assertAllEqual(result.indices, expected_indices)
        self.assertAllEqual(
            result.values,
            _constant([1, 3, 5, 7, 9, 0, 2, 4, 5, 6, 6, 8, 9], dtype))
예제 #6
0
 def _set_union_count(self, a, b):
     """Evaluate `sets.set_size` of the union of `a` and `b` and return it."""
     union = sets.set_union(a, b)
     size_op = sets.set_size(union)
     with self.test_session() as session:
         size = session.run(size_op)
     return size
예제 #7
0
 def _set_union_count(self, a, b):
   """Run `sets.set_size` on the union of `a` and `b`; return the value."""
   union = sets.set_union(a, b)
   size_op = sets.set_size(union)
   with self.cached_session() as session:
     fetched = session.run(size_op)
   return fetched
예제 #8
0
 def _set_union_count(self, a, b):
   """Return the evaluated `sets.set_size` of the union of `a` and `b`.

   The op is evaluated with `self.evaluate`, so the session object yielded by
   `cached_session()` was never used; the `as sess` binding is dropped. The
   context manager itself is kept, presumably so evaluation still happens
   inside the test's cached session.
   """
   op = sets.set_size(sets.set_union(a, b))
   with self.cached_session():
     return self.evaluate(op)