def _set_union(self, a, b):
  """Builds the union of `a` and `b` in every equivalent form and runs them.

  One `sets.set_union` op is created for each combination of operand order
  and `validate_indices` flag. Each op is passed, together with `a`, to
  `self._assert_static_shapes`. The tuple of ops is then handed to
  `self._run_equivalent_set_ops`, whose return value is passed through
  unchanged.

  Args:
    a: first set operand.
    b: second set operand.

  Returns:
    Whatever `self._run_equivalent_set_ops` returns for the four ops.
  """
  # Validate that we get the same results with or without `validate_indices`,
  # and with a & b swapped. Iteration order yields (a, b, True),
  # (a, b, False), (b, a, True), (b, a, False).
  ops = tuple(
      sets.set_union(lhs, rhs, validate_indices=validate)
      for lhs, rhs in ((a, b), (b, a))
      for validate in (True, False))
  for op in ops:
    self._assert_static_shapes(a, op)
  return self._run_equivalent_set_ops(ops)
def _set_union(self, a, b):
  """Returns the checked union of `a` and `b` across equivalent call forms.

  Builds `sets.set_union` for both operand orders, each with
  `validate_indices=True` and then `False`. Every op goes through
  `self._assert_static_shapes` (with `a` as reference). All four are then
  evaluated by `self._run_equivalent_set_ops`.

  Args:
    a: first set operand.
    b: second set operand.

  Returns:
    The value produced by `self._run_equivalent_set_ops`.
  """
  # Validate that we get the same results with or without `validate_indices`,
  # and with a & b swapped.
  union_ops = []
  for first, second in ((a, b), (b, a)):
    for validate in (True, False):
      union_ops.append(
          sets.set_union(first, second, validate_indices=validate))
  union_ops = tuple(union_ops)

  for union_op in union_ops:
    self._assert_static_shapes(a, union_op)
  return self._run_equivalent_set_ops(union_ops)
def _set_union(self, a, b):
  """Evaluates the union of `a` and `b`, checking equivalent forms agree.

  Builds `sets.set_union` for both operand orders, each with and without
  `validate_indices`. Each op is passed, together with `a`, to
  `self._assert_shapes`. All ops are evaluated in one session run, and the
  method asserts that every variant produced identical `indices`, `values`
  and `dense_shape`.

  Args:
    a: first set operand.
    b: second set operand.

  Returns:
    The evaluated result of `sets.set_union(a, b, validate_indices=True)`.
  """
  # Validate that we get the same results with or without `validate_indices`,
  # and with a & b swapped.
  ops = (
      sets.set_union(a, b, validate_indices=True),
      sets.set_union(a, b, validate_indices=False),
      sets.set_union(b, a, validate_indices=True),
      sets.set_union(b, a, validate_indices=False),
  )
  for op in ops:
    self._assert_shapes(a, op)
  # `test_session` is deprecated; `cached_session` is its replacement.
  with self.cached_session() as sess:
    results = sess.run(ops)
    reference = results[0]
    # Compare every other variant against the first. The count comes from
    # `results`, so it stays correct if variants are added or removed.
    for other in results[1:]:
      self.assertAllEqual(reference.indices, other.indices)
      self.assertAllEqual(reference.values, other.values)
      self.assertAllEqual(reference.dense_shape, other.dense_shape)
    return reference
def _set_union(self, a, b):
  """Evaluates the union of `a` and `b`, checking equivalent forms agree.

  Builds `sets.set_union` for both operand orders, each with and without
  `validate_indices`. Each op is passed, together with `a`, to
  `self._assert_shapes`. All ops are evaluated in one session run, and the
  method asserts that every variant produced identical `indices`, `values`
  and `dense_shape`.

  Args:
    a: first set operand.
    b: second set operand.

  Returns:
    The evaluated result of `sets.set_union(a, b, validate_indices=True)`.
  """
  # Validate that we get the same results with or without `validate_indices`,
  # and with a & b swapped.
  ops = (
      sets.set_union(a, b, validate_indices=True),
      sets.set_union(a, b, validate_indices=False),
      sets.set_union(b, a, validate_indices=True),
      sets.set_union(b, a, validate_indices=False),
  )
  for op in ops:
    self._assert_shapes(a, op)
  # `test_session` is deprecated; `cached_session` is its replacement.
  with self.cached_session() as sess:
    results = sess.run(ops)
    reference = results[0]
    # Compare every other variant against the first. The count comes from
    # `results`, so it stays correct if variants are added or removed.
    for other in results[1:]:
      self.assertAllEqual(reference.indices, other.indices)
      self.assertAllEqual(reference.values, other.values)
      self.assertAllEqual(reference.dense_shape, other.dense_shape)
    return reference
def test_set_union_output_is_sorted(self, dtype):
  """Checks that each row of `sets.set_union` output is in ascending order.

  Every value used is below 10. This keeps lexicographical order equal to
  numeric order when `dtype` is `tf.string`.
  """
  # Set A, one row per line:
  #   [3 7 5 3 1]
  #   [2 6 5 4]
  #   []
  #   [9 8]
  a_indices = [
      [0, 0], [0, 1], [0, 2], [0, 3], [0, 4],
      [1, 0], [1, 1], [1, 2], [1, 3],
      [3, 0], [3, 1],
  ]
  a_values = [3, 7, 5, 3, 1, 2, 6, 5, 4, 9, 8]
  sp_a = sparse_tensor_lib.SparseTensor(
      indices=constant_op.constant(a_indices, dtype=dtypes.int64),
      values=_constant(a_values, dtype),
      dense_shape=constant_op.constant([4, 5], dtype=dtypes.int64))

  # Set B, one row per line:
  #   [9 7]
  #   [5 2 0]
  #   [6]
  #   []
  b_indices = [
      [0, 0], [0, 1],
      [1, 0], [1, 1], [1, 2],
      [2, 0],
  ]
  b_values = [9, 7, 5, 2, 0, 6]
  sp_b = sparse_tensor_lib.SparseTensor(
      indices=constant_op.constant(b_indices, dtype=dtypes.int64),
      values=_constant(b_values, dtype),
      dense_shape=constant_op.constant([4, 3], dtype=dtypes.int64))

  # Expected union, sorted within each row:
  #   [1 3 5 7 9]
  #   [0 2 4 5 6]
  #   [6]
  #   [8 9]
  expected_indices = [
      [0, 0], [0, 1], [0, 2], [0, 3], [0, 4],
      [1, 0], [1, 1], [1, 2], [1, 3], [1, 4],
      [2, 0],
      [3, 0], [3, 1],
  ]
  expected_values = [1, 3, 5, 7, 9, 0, 2, 4, 5, 6, 6, 8, 9]

  union = sets.set_union(sp_a, sp_b)
  self.assertAllEqual(union.dense_shape, [4, 5])
  self.assertAllEqual(union.indices, expected_indices)
  self.assertAllEqual(union.values, _constant(expected_values, dtype))
def _set_union_count(self, a, b):
  """Evaluates `sets.set_size` applied to the union of `a` and `b`.

  Args:
    a: first set operand.
    b: second set operand.

  Returns:
    The evaluated value of `sets.set_size(sets.set_union(a, b))`.
  """
  op = sets.set_size(sets.set_union(a, b))
  # `test_session` is deprecated; `cached_session` is its replacement.
  with self.cached_session() as sess:
    return sess.run(op)
def _set_union_count(self, a, b):
  """Returns the evaluated `sets.set_size` of the union of `a` and `b`.

  Args:
    a: first set operand.
    b: second set operand.

  Returns:
    The result of running the size op in the test's cached session.
  """
  union = sets.set_union(a, b)
  size_op = sets.set_size(union)
  with self.cached_session() as session:
    size_value = session.run(size_op)
  return size_value
def _set_union_count(self, a, b):
  """Evaluates `sets.set_size` applied to the union of `a` and `b`.

  Args:
    a: first set operand.
    b: second set operand.

  Returns:
    The value of `sets.set_size(sets.set_union(a, b))` from `self.evaluate`.
  """
  op = sets.set_size(sets.set_union(a, b))
  # The session handle was never used, so it is no longer bound. The context
  # itself is kept: presumably `self.evaluate` needs it as the default
  # session in graph mode (confirm before removing).
  with self.cached_session():
    return self.evaluate(op)