def test_function(
    statement: FunctionElement, output: str, dialect: SQLCompiler
) -> None:
    """Compile ``statement`` with literal binds and compare it to ``output``.

    :param statement: construct to compile; its ``compile()`` method is
        called with the given dialect.
    :param output: expected SQL text; compared through
        ``eq_ignore_whitespace``.
    :param dialect: object passed as the ``dialect`` argument of
        ``statement.compile()``.
    """
    # "literal_binds" is forwarded unchanged to the compiler via
    # compile_kwargs.
    render_options = {"literal_binds": True}
    rendered_sql = str(
        statement.compile(dialect=dialect, compile_kwargs=render_options)
    )
    eq_ignore_whitespace(rendered_sql, output)
class CoreFixtures(object):
    """Container of SQL construct fixtures.

    ``fixtures`` is a list of zero-argument callables.  Each callable
    returns a sequence of constructs that, per the note below, are expected
    to compare as different from one another.  Besides the inline lambdas,
    three helper functions defined in the class body are appended to the
    same list.

    ``dont_compare_values_fixtures`` is a separate list whose constructs are
    built from randomly generated values.
    """

    # lambdas which return a tuple of ColumnElement objects.
    # must return at least two objects that should compare differently.
    # to test more varieties of "difference" additional objects can be added.
    fixtures = [
        lambda: (
            column("q"),
            column("x"),
            column("q", Integer),
            column("q", String),
        ),
        lambda: (~column("q", Boolean), ~column("p", Boolean)),
        lambda: (
            table_a.c.a.label("foo"),
            table_a.c.a.label("bar"),
            table_a.c.b.label("foo"),
        ),
        lambda: (
            _label_reference(table_a.c.a.desc()),
            _label_reference(table_a.c.a.asc()),
        ),
        lambda: (_textual_label_reference("a"), _textual_label_reference("b")),
        lambda: (
            text("select a, b from table").columns(a=Integer, b=String),
            text("select a, b, c from table").columns(
                a=Integer, b=String, c=Integer),
            text("select a, b, c from table where foo=:bar").bindparams(
                bindparam("bar", type_=Integer)),
            text("select a, b, c from table where foo=:foo").bindparams(
                bindparam("foo", type_=Integer)),
            text("select a, b, c from table where foo=:bar").bindparams(
                bindparam("bar", type_=String)),
        ),
        lambda: (
            column("q") == column("x"),
            column("q") == column("y"),
            column("z") == column("x"),
            column("z") + column("x"),
            column("z") - column("x"),
            column("x") - column("z"),
            column("z") > column("x"),
            column("x").in_([5, 7]),
            column("x").in_([10, 7, 8]),
            # note these two are mathematically equivalent but for now they
            # are considered to be different
            column("z") >= column("x"),
            column("x") <= column("z"),
            column("q").between(5, 6),
            column("q").between(5, 6, symmetric=True),
            column("q").like("somstr"),
            column("q").like("somstr", escape="\\"),
            column("q").like("somstr", escape="X"),
        ),
        # the same column carrying different annotation dictionaries
        lambda: (
            table_a.c.a,
            table_a.c.a._annotate({"orm": True}),
            table_a.c.a._annotate({
                "orm": True
            })._annotate({"bar": False}),
            table_a.c.a._annotate({
                "orm": True,
                "parententity": MyEntity("a", table_a)
            }),
            table_a.c.a._annotate({
                "orm": True,
                "parententity": MyEntity("b", table_a)
            }),
            table_a.c.a._annotate(
                {
                    "orm": True,
                    "parententity": MyEntity("b", select([table_a]))
                }),
            table_a.c.a._annotate({
                "orm": True,
                "parententity":
                MyEntity("b", select([table_a]).where(table_a.c.a == 5)),
            }),
        ),
        # the same table carrying different annotation dictionaries
        lambda: (
            table_a,
            table_a._annotate({"orm": True}),
            table_a._annotate({
                "orm": True
            })._annotate({"bar": False}),
            table_a._annotate({
                "orm": True,
                "parententity": MyEntity("a", table_a)
            }),
            table_a._annotate({
                "orm": True,
                "parententity": MyEntity("b", table_a)
            }),
            table_a._annotate({
                "orm": True,
                "parententity": MyEntity("b", select([table_a]))
            }),
        ),
        lambda: (
            table("a", column("x"), column("y")),
            table("a", column("x"), column("y"))._annotate({"orm": True}),
            table("b", column("x"), column("y"))._annotate({"orm": True}),
        ),
        lambda: (
            cast(column("q"), Integer),
            cast(column("q"), Float),
            cast(column("p"), Integer),
        ),
        lambda: (
            bindparam("x"),
            bindparam("y"),
            bindparam("x", type_=Integer),
            bindparam("x", type_=String),
            bindparam(None),
        ),
        lambda: (_OffsetLimitParam("x"), _OffsetLimitParam("y")),
        lambda: (func.foo(), func.foo(5), func.bar()),
        lambda: (func.current_date(), func.current_time()),
        lambda: (
            func.next_value(Sequence("q")),
            func.next_value(Sequence("p")),
        ),
        lambda: (True_(), False_()),
        # NOTE(review): single-element tuple, unlike the "at least two
        # objects" rule stated above this list -- presumably intentional;
        # confirm.
        lambda: (Null(), ),
        lambda: (ReturnTypeFromArgs("foo"), ReturnTypeFromArgs(5)),
        lambda: (FunctionElement(5), FunctionElement(5, 6)),
        lambda: (func.count(), func.not_count()),
        lambda: (func.char_length("abc"), func.char_length("def")),
        lambda: (GenericFunction("a", "b"), GenericFunction("a")),
        lambda: (CollationClause("foobar"), CollationClause("batbar")),
        lambda: (
            type_coerce(column("q", Integer), String),
            type_coerce(column("q", Integer), Float),
            type_coerce(column("z", Integer), Float),
        ),
        lambda: (table_a.c.a, table_b.c.a),
        lambda: (tuple_(1, 2), tuple_(3, 4)),
        lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
        lambda: (
            func.percentile_cont(0.5).within_group(table_a.c.a),
            func.percentile_cont(0.5).within_group(table_a.c.b),
            func.percentile_cont(0.5).within_group(table_a.c.a, table_a.c.b),
            func.percentile_cont(0.5).within_group(table_a.c.a, table_a.c.b,
                                                   column("q")),
        ),
        lambda: (
            func.is_equal("a", "b").as_comparison(1, 2),
            func.is_equal("a", "c").as_comparison(1, 2),
            func.is_equal("a", "b").as_comparison(2, 1),
            func.is_equal("a", "b", "c").as_comparison(1, 2),
            func.foobar("a", "b").as_comparison(1, 2),
        ),
        lambda: (
            func.row_number().over(order_by=table_a.c.a),
            func.row_number().over(order_by=table_a.c.a, range_=(0, 10)),
            func.row_number().over(order_by=table_a.c.a, range_=(None, 10)),
            func.row_number().over(order_by=table_a.c.a, rows=(None, 20)),
            func.row_number().over(order_by=table_a.c.b),
            func.row_number().over(order_by=table_a.c.a,
                                   partition_by=table_a.c.b),
        ),
        lambda: (
            func.count(1).filter(table_a.c.a == 5),
            func.count(1).filter(table_a.c.a == 10),
            func.foob(1).filter(table_a.c.a == 10),
        ),
        lambda: (
            and_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            and_(table_a.c.a == 5, table_a.c.a == table_b.c.a),
            or_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            ClauseList(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            ClauseList(table_a.c.a == 5, table_a.c.b == table_a.c.a),
        ),
        lambda: (
            case(whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)]),
            case(whens=[(table_a.c.a == 18, 10), (table_a.c.a == 10, 20)]),
            case(whens=[(table_a.c.a == 5, 10), (table_a.c.b == 10, 20)]),
            case(whens=[
                (table_a.c.a == 5, 10),
                (table_a.c.b == 10, 20),
                (table_a.c.a == 9, 12),
            ]),
            case(
                whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)],
                else_=30,
            ),
            case({
                "wendy": "W",
                "jack": "J"
            }, value=table_a.c.a, else_="E"),
            case({
                "wendy": "W",
                "jack": "J"
            }, value=table_a.c.b, else_="E"),
            case({
                "wendy_w": "W",
                "jack": "J"
            }, value=table_a.c.a, else_="E"),
        ),
        lambda: (
            extract("foo", table_a.c.a),
            extract("foo", table_a.c.b),
            extract("bar", table_a.c.a),
        ),
        lambda: (
            Slice(1, 2, 5),
            Slice(1, 5, 5),
            Slice(1, 5, 10),
            Slice(2, 10, 15),
        ),
        lambda: (
            select([table_a.c.a]),
            select([table_a.c.a, table_a.c.b]),
            select([table_a.c.b, table_a.c.a]),
            select([table_a.c.b,
                    table_a.c.a]).apply_labels(),
            select([table_a.c.a]).where(table_a.c.b == 5),
            select([table_a.c.a]).where(table_a.c.b == 5).where(table_a.c.a == 10),
            select([table_a.c.a]).where(table_a.c.b == 5).with_for_update(),
            select([table_a.c.a]).where(table_a.c.b == 5).with_for_update(
                nowait=True),
            select([table_a.c.a]).where(table_a.c.b == 5).correlate(table_b),
            select([table_a.c.a]).where(table_a.c.b == 5).correlate_except(
                table_b),
        ),
        lambda: (
            future_select(table_a.c.a),
            future_select(table_a.c.a).join(table_b, table_a.c.a == table_b.c.a
                                            ),
            future_select(table_a.c.a).join_from(table_a, table_b,
                                                 table_a.c.a == table_b.c.a),
            future_select(table_a.c.a).join_from(table_a, table_b),
            future_select(table_a.c.a).join_from(table_c, table_b),
            future_select(table_a.c.a).join(
                table_b, table_a.c.a == table_b.c.a).join(
                    table_c, table_b.c.b == table_c.c.x),
            future_select(table_a.c.a).join(table_b),
            future_select(table_a.c.a).join(table_c),
            future_select(table_a.c.a).join(table_b, table_a.c.a == table_b.c.b
                                            ),
            future_select(table_a.c.a).join(table_c, table_a.c.a == table_c.c.x
                                            ),
        ),
        lambda: (
            select([table_a.c.a]).cte(),
            select([table_a.c.a]).cte(recursive=True),
            select([table_a.c.a]).cte(name="some_cte", recursive=True),
            select([table_a.c.a]).cte(name="some_cte"),
            select([table_a.c.a]).cte(name="some_cte").alias("other_cte"),
            select([table_a.c.a]).cte(name="some_cte").union_all(
                select([table_a.c.a])),
            select([table_a.c.a]).cte(name="some_cte").union_all(
                select([table_a.c.b])),
            select([table_a.c.a]).lateral(),
            select([table_a.c.a]).lateral(name="bar"),
            table_a.tablesample(func.bernoulli(1)),
            table_a.tablesample(func.bernoulli(1), seed=func.random()),
            table_a.tablesample(func.bernoulli(1), seed=func.other_random()),
            table_a.tablesample(func.hoho(1)),
            table_a.tablesample(func.bernoulli(1), name="bar"),
            table_a.tablesample(
                func.bernoulli(1), name="bar", seed=func.random()),
        ),
        # INSERT constructs
        lambda: (
            table_a.insert(),
            table_a.insert().values({})._annotate({"nocache": True}),
            table_b.insert(),
            table_b.insert().with_dialect_options(sqlite_foo="some value"),
            table_b.insert().from_select(["a", "b"], select([table_a])),
            table_b.insert().from_select(["a", "b"],
                                         select([table_a]).where(table_a.c.a > 5)),
            table_b.insert().from_select(["a", "b"], select([table_b])),
            table_b.insert().from_select(["c", "d"], select([table_a])),
            table_b.insert().returning(table_b.c.a),
            table_b.insert().returning(table_b.c.a, table_b.c.b),
            table_b.insert().inline(),
            table_b.insert().prefix_with("foo"),
            table_b.insert().with_hint("RUNFAST"),
            table_b.insert().values(a=5, b=10),
            table_b.insert().values(a=5),
            table_b.insert().values({
                table_b.c.a: 5,
                "b": 10
            })._annotate({"nocache": True}),
            table_b.insert().values(a=7, b=10),
            table_b.insert().values(a=5, b=10).inline(),
            table_b.insert().values([{
                "a": 5,
                "b": 10
            }, {
                "a": 8,
                "b": 12
            }])._annotate({"nocache": True}),
            table_b.insert().values([{
                "a": 9,
                "b": 10
            }, {
                "a": 8,
                "b": 7
            }])._annotate({"nocache": True}),
            table_b.insert().values([(5, 10),
                                     (8, 12)])._annotate({"nocache": True}),
            table_b.insert().values([(5, 9),
                                     (5, 12)])._annotate({"nocache": True}),
        ),
        # UPDATE constructs
        lambda: (
            table_b.update(),
            table_b.update().where(table_b.c.a == 5),
            table_b.update().where(table_b.c.b == 5),
            table_b.update().where(table_b.c.b == 5).with_dialect_options(
                mysql_limit=10),
            table_b.update().where(table_b.c.b == 5).with_dialect_options(
                mysql_limit=10, sqlite_foo="some value"),
            table_b.update().where(table_b.c.a == 5).values(a=5, b=10),
            table_b.update().where(table_b.c.a == 5).values(a=5, b=10, c=12),
            table_b.update().where(table_b.c.b == 5).values(a=5, b=10).
            _annotate({"nocache": True}),
            table_b.update().values(a=5, b=10),
            table_b.update().values({
                "a": 5,
                table_b.c.b: 10
            })._annotate({"nocache": True}),
            table_b.update().values(a=7, b=10),
            table_b.update().ordered_values(("a", 5), ("b", 10)),
            table_b.update().ordered_values(("b", 10), ("a", 5)),
            table_b.update().ordered_values((table_b.c.a, 5), ("b", 10)),
        ),
        # DELETE constructs
        lambda: (
            table_b.delete(),
            table_b.delete().with_dialect_options(sqlite_foo="some value"),
            table_b.delete().where(table_b.c.a == 5),
            table_b.delete().where(table_b.c.b == 5),
        ),
        # VALUES constructs differing in name, data, or column name
        lambda: (
            values(
                column("mykey", Integer),
                column("mytext", String),
                column("myint", Integer),
                name="myvalues",
            ).data([(1, "textA", 99),
                    (2, "textB", 88)])._annotate({"nocache": True}),
            values(
                column("mykey", Integer),
                column("mytext", String),
                column("myint", Integer),
                name="myothervalues",
            ).data([(1, "textA", 99),
                    (2, "textB", 88)])._annotate({"nocache": True}),
            values(
                column("mykey", Integer),
                column("mytext", String),
                column("myint", Integer),
                name="myvalues",
            ).data([(1, "textA", 89),
                    (2, "textG", 88)])._annotate({"nocache": True}),
            values(
                column("mykey", Integer),
                column("mynottext", String),
                column("myint", Integer),
                name="myvalues",
            ).data([(1, "textA", 99),
                    (2, "textB", 88)])._annotate({"nocache": True}),
            # TODO: difference in type
            # values(
            #    [
            #        column("mykey", Integer),
            #        column("mytext", Text),
            #        column("myint", Integer),
            #    ],
            #    (1, "textA", 99),
            #    (2, "textB", 88),
            #    alias_name="myvalues",
            # ),
        ),
        lambda: (
            select([table_a.c.a]),
            select([table_a.c.a]).prefix_with("foo"),
            select([table_a.c.a]).prefix_with("foo", dialect="mysql"),
            select([table_a.c.a]).prefix_with("foo", dialect="postgresql"),
            select([table_a.c.a]).prefix_with("bar"),
            select([table_a.c.a]).suffix_with("bar"),
        ),
        lambda: (
            select([table_a_2.c.a]),
            select([table_a_2_fs.c.a]),
            select([table_a_2_bs.c.a]),
        ),
        lambda: (
            select([table_a.c.a]),
            select([table_a.c.a]).with_hint(None, "some hint"),
            select([table_a.c.a]).with_hint(None, "some other hint"),
            select([table_a.c.a]).with_hint(table_a, "some hint"),
            select([table_a.c.a]).with_hint(table_a, "some hint").with_hint(
                None, "some other hint"),
            select([table_a.c.a]).with_hint(table_a, "some other hint"),
            select([table_a.c.a]).with_hint(
                table_a, "some hint", dialect_name="mysql"),
            select([table_a.c.a]).with_hint(
                table_a, "some hint", dialect_name="postgresql"),
        ),
        lambda: (
            table_a.join(table_b, table_a.c.a == table_b.c.a),
            table_a.join(table_b,
                         and_(table_a.c.a == table_b.c.a, table_a.c.b == 1)),
            table_a.outerjoin(table_b, table_a.c.a == table_b.c.a),
        ),
        lambda: (
            table_a.alias("a"),
            table_a.alias("b"),
            table_a.alias(),
            table_b.alias("a"),
            select([table_a.c.a]).alias("a"),
        ),
        lambda: (
            FromGrouping(table_a.alias("a")),
            FromGrouping(table_a.alias("b")),
        ),
        lambda: (
            SelectStatementGrouping(select([table_a])),
            SelectStatementGrouping(select([table_b])),
        ),
        lambda: (
            select([table_a.c.a]).scalar_subquery(),
            select([table_a.c.a]).where(table_a.c.b == 5).scalar_subquery(),
        ),
        lambda: (
            exists().where(table_a.c.a == 5),
            exists().where(table_a.c.b == 5),
        ),
        lambda: (
            union(select([table_a.c.a]), select([table_a.c.b])),
            union(select([table_a.c.a]), select([table_a.c.b])).order_by("a"),
            union_all(select([table_a.c.a]), select([table_a.c.b])),
            union(select([table_a.c.a])),
            union(
                select([table_a.c.a]),
                select([table_a.c.b]).where(table_a.c.b > 5),
            ),
        ),
        lambda: (
            table("a", column("x"), column("y")),
            table("a", column("y"), column("x")),
            table("b", column("x"), column("y")),
            table("a", column("x"), column("y"), column("z")),
            table("a", column("x"), column("y", Integer)),
            table("a", column("q"), column("y", Integer)),
        ),
        lambda: (table_a, table_b),
    ]

    # Constructs whose literal values are regenerated randomly on every call
    # (random_choices / random.randint).
    # NOTE(review): presumably consumed by tests that compare structure while
    # ignoring the bound values -- confirm against the tests using this list.
    dont_compare_values_fixtures = [
        lambda: (
            # note the in_(...) all have different column names because
            # otherwise all IN expressions would compare as equivalent
            column("x").in_(random_choices(range(10), k=3)),
            column("y").in_(
                bindparam(
                    "q",
                    random_choices(range(10), k=random.randint(0, 7)),
                    expanding=True,
                )),
            column("z").in_(random_choices(range(10), k=random.randint(0, 7))),
            column("x") == random.randint(1, 10),
        )
    ]

    def _complex_fixtures():
        """Return a list of four multi-table SELECT statements.

        ``one`` and ``one_diff`` are identical except that the two anonymous
        aliases wrap ``table_a`` / ``table_b_like_a`` in swapped order;
        ``two`` joins ``table_b`` to a subquery of ``one``; ``three`` joins
        two aliases of ``table_a`` and filters on an EXISTS clause.
        """

        def one():
            a1 = table_a.alias()
            a2 = table_b_like_a.alias()

            stmt = (select([table_a.c.a, a1.c.b,
                            a2.c.b]).where(table_a.c.b == a1.c.b).where(
                                a1.c.b == a2.c.b).where(a1.c.a == 5))

            return stmt

        def one_diff():
            a1 = table_b_like_a.alias()
            a2 = table_a.alias()

            stmt = (select([table_a.c.a, a1.c.b,
                            a2.c.b]).where(table_a.c.b == a1.c.b).where(
                                a1.c.b == a2.c.b).where(a1.c.a == 5))

            return stmt

        def two():
            inner = one().subquery()

            stmt = select([table_b.c.a, inner.c.a, inner.c.b]).select_from(
                table_b.join(inner, table_b.c.b == inner.c.b))

            return stmt

        def three():
            a1 = table_a.alias()
            a2 = table_a.alias()
            ex = exists().where(table_b.c.b == a1.c.a)

            stmt = (select([a1.c.a, a2.c.a]).select_from(
                a1.join(a2, a1.c.b == a2.c.b)).where(ex))
            return stmt

        return [one(), one_diff(), two(), three()]

    # the function itself (not its result) is appended, so it sits in the
    # list as a callable alongside the lambdas
    fixtures.append(_complex_fixtures)

    def _statements_w_context_options_fixtures():
        """Return SELECTs of ``table_a`` differing only in context options.

        Each statement adds ``opt1`` / ``opt2`` / ``opt3`` (names defined
        outside this class) via ``_add_context_option`` with varying values.
        """
        return [
            select([table_a])._add_context_option(opt1, True),
            select([table_a])._add_context_option(opt1, 5),
            select([table_a])._add_context_option(opt1, True)._add_context_option(
                opt2, True),
            select([table_a
                    ])._add_context_option(opt1, True)._add_context_option(opt2, 5),
            select([table_a])._add_context_option(opt3, True),
        ]

    fixtures.append(_statements_w_context_options_fixtures)

    def _statements_w_anonymous_col_names():
        """Return three constructs built on columns labeled with ``label(None)``.

        ``one`` and ``two`` differ only in the underlying column name
        ("q" vs. "p"); ``three`` selects from a subquery containing two such
        labels of ``table_a`` columns.
        """

        def one():
            c = column("q")

            l = c.label(None)

            # new case as of Id810f485c5f7ed971529489b84694e02a3356d6d
            subq = select([l]).subquery()

            # this creates a ColumnClause as a proxy to the Label() that has
            # an anonymous name, so the column has one too.
            anon_col = subq.c[0]

            # then when BindParameter is created, it checks the label
            # and doesn't double up on the anonymous name which is uncachable
            return anon_col > 5

        def two():
            c = column("p")

            l = c.label(None)

            # new case as of Id810f485c5f7ed971529489b84694e02a3356d6d
            subq = select([l]).subquery()

            # this creates a ColumnClause as a proxy to the Label() that has
            # an anonymous name, so the column has one too.
            anon_col = subq.c[0]

            # then when BindParameter is created, it checks the label
            # and doesn't double up on the anonymous name which is uncachable
            return anon_col > 5

        def three():
            l1, l2 = table_a.c.a.label(None), table_a.c.b.label(None)

            stmt = select([table_a.c.a, table_a.c.b, l1, l2])

            subq = stmt.subquery()
            return select([subq]).where(subq.c[2] == 10)

        return (
            one(),
            two(),
            three(),
        )

    fixtures.append(_statements_w_anonymous_col_names)
class CoreFixtures(object):
    """Container of SQL construct fixtures.

    ``fixtures`` is a list of zero-argument callables.  Each callable
    returns a sequence of constructs that, per the note below, are expected
    to compare as different from one another.  Besides the inline lambdas,
    the ``_complex_fixtures`` helper defined in the class body is appended
    to the same list.

    ``dont_compare_values_fixtures`` is a separate list whose constructs are
    built from randomly generated values.
    """

    # lambdas which return a tuple of ColumnElement objects.
    # must return at least two objects that should compare differently.
    # to test more varieties of "difference" additional objects can be added.
    fixtures = [
        lambda: (
            column("q"),
            column("x"),
            column("q", Integer),
            column("q", String),
        ),
        lambda: (~column("q", Boolean), ~column("p", Boolean)),
        lambda: (
            table_a.c.a.label("foo"),
            table_a.c.a.label("bar"),
            table_a.c.b.label("foo"),
        ),
        lambda: (
            _label_reference(table_a.c.a.desc()),
            _label_reference(table_a.c.a.asc()),
        ),
        lambda: (_textual_label_reference("a"), _textual_label_reference("b")),
        lambda: (
            text("select a, b from table").columns(a=Integer, b=String),
            text("select a, b, c from table").columns(
                a=Integer, b=String, c=Integer
            ),
            text("select a, b, c from table where foo=:bar").bindparams(
                bindparam("bar", type_=Integer)
            ),
            text("select a, b, c from table where foo=:foo").bindparams(
                bindparam("foo", type_=Integer)
            ),
            text("select a, b, c from table where foo=:bar").bindparams(
                bindparam("bar", type_=String)
            ),
        ),
        lambda: (
            column("q") == column("x"),
            column("q") == column("y"),
            column("z") == column("x"),
            column("z") + column("x"),
            column("z") - column("x"),
            column("x") - column("z"),
            column("z") > column("x"),
            column("x").in_([5, 7]),
            column("x").in_([10, 7, 8]),
            # note these two are mathematically equivalent but for now they
            # are considered to be different
            column("z") >= column("x"),
            column("x") <= column("z"),
            column("q").between(5, 6),
            column("q").between(5, 6, symmetric=True),
            column("q").like("somstr"),
            column("q").like("somstr", escape="\\"),
            column("q").like("somstr", escape="X"),
        ),
        # the same column carrying different annotation dictionaries
        lambda: (
            table_a.c.a,
            table_a.c.a._annotate({"orm": True}),
            table_a.c.a._annotate({"orm": True})._annotate({"bar": False}),
            table_a.c.a._annotate(
                {"orm": True, "parententity": MyEntity("a", table_a)}
            ),
            table_a.c.a._annotate(
                {"orm": True, "parententity": MyEntity("b", table_a)}
            ),
            table_a.c.a._annotate(
                {"orm": True, "parententity": MyEntity("b", select([table_a]))}
            ),
        ),
        lambda: (
            cast(column("q"), Integer),
            cast(column("q"), Float),
            cast(column("p"), Integer),
        ),
        lambda: (
            bindparam("x"),
            bindparam("y"),
            bindparam("x", type_=Integer),
            bindparam("x", type_=String),
            bindparam(None),
        ),
        lambda: (_OffsetLimitParam("x"), _OffsetLimitParam("y")),
        lambda: (func.foo(), func.foo(5), func.bar()),
        lambda: (func.current_date(), func.current_time()),
        lambda: (
            func.next_value(Sequence("q")),
            func.next_value(Sequence("p")),
        ),
        lambda: (True_(), False_()),
        # NOTE(review): single-element tuple, unlike the "at least two
        # objects" rule stated above this list -- presumably intentional;
        # confirm.
        lambda: (Null(),),
        lambda: (ReturnTypeFromArgs("foo"), ReturnTypeFromArgs(5)),
        lambda: (FunctionElement(5), FunctionElement(5, 6)),
        lambda: (func.count(), func.not_count()),
        lambda: (func.char_length("abc"), func.char_length("def")),
        lambda: (GenericFunction("a", "b"), GenericFunction("a")),
        lambda: (CollationClause("foobar"), CollationClause("batbar")),
        lambda: (
            type_coerce(column("q", Integer), String),
            type_coerce(column("q", Integer), Float),
            type_coerce(column("z", Integer), Float),
        ),
        lambda: (table_a.c.a, table_b.c.a),
        lambda: (tuple_(1, 2), tuple_(3, 4)),
        lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
        lambda: (
            func.percentile_cont(0.5).within_group(table_a.c.a),
            func.percentile_cont(0.5).within_group(table_a.c.b),
            func.percentile_cont(0.5).within_group(table_a.c.a, table_a.c.b),
            func.percentile_cont(0.5).within_group(
                table_a.c.a, table_a.c.b, column("q")
            ),
        ),
        lambda: (
            func.is_equal("a", "b").as_comparison(1, 2),
            func.is_equal("a", "c").as_comparison(1, 2),
            func.is_equal("a", "b").as_comparison(2, 1),
            func.is_equal("a", "b", "c").as_comparison(1, 2),
            func.foobar("a", "b").as_comparison(1, 2),
        ),
        lambda: (
            func.row_number().over(order_by=table_a.c.a),
            func.row_number().over(order_by=table_a.c.a, range_=(0, 10)),
            func.row_number().over(order_by=table_a.c.a, range_=(None, 10)),
            func.row_number().over(order_by=table_a.c.a, rows=(None, 20)),
            func.row_number().over(order_by=table_a.c.b),
            func.row_number().over(
                order_by=table_a.c.a, partition_by=table_a.c.b
            ),
        ),
        lambda: (
            func.count(1).filter(table_a.c.a == 5),
            func.count(1).filter(table_a.c.a == 10),
            func.foob(1).filter(table_a.c.a == 10),
        ),
        lambda: (
            and_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            and_(table_a.c.a == 5, table_a.c.a == table_b.c.a),
            or_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            ClauseList(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            ClauseList(table_a.c.a == 5, table_a.c.b == table_a.c.a),
        ),
        lambda: (
            case(whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)]),
            case(whens=[(table_a.c.a == 18, 10), (table_a.c.a == 10, 20)]),
            case(whens=[(table_a.c.a == 5, 10), (table_a.c.b == 10, 20)]),
            case(
                whens=[
                    (table_a.c.a == 5, 10),
                    (table_a.c.b == 10, 20),
                    (table_a.c.a == 9, 12),
                ]
            ),
            case(
                whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)],
                else_=30,
            ),
            case({"wendy": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
            case({"wendy": "W", "jack": "J"}, value=table_a.c.b, else_="E"),
            case({"wendy_w": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
        ),
        lambda: (
            extract("foo", table_a.c.a),
            extract("foo", table_a.c.b),
            extract("bar", table_a.c.a),
        ),
        lambda: (
            Slice(1, 2, 5),
            Slice(1, 5, 5),
            Slice(1, 5, 10),
            Slice(2, 10, 15),
        ),
        lambda: (
            select([table_a.c.a]),
            select([table_a.c.a, table_a.c.b]),
            select([table_a.c.b, table_a.c.a]),
            select([table_a.c.a]).where(table_a.c.b == 5),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .where(table_a.c.a == 10),
            select([table_a.c.a]).where(table_a.c.b == 5).with_for_update(),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .with_for_update(nowait=True),
            select([table_a.c.a]).where(table_a.c.b == 5).correlate(table_b),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .correlate_except(table_b),
        ),
        lambda: (
            select([table_a.c.a]).cte(),
            select([table_a.c.a]).cte(recursive=True),
            select([table_a.c.a]).cte(name="some_cte", recursive=True),
            select([table_a.c.a]).cte(name="some_cte"),
            select([table_a.c.a]).cte(name="some_cte").alias("other_cte"),
            select([table_a.c.a])
            .cte(name="some_cte")
            .union_all(select([table_a.c.a])),
            select([table_a.c.a])
            .cte(name="some_cte")
            .union_all(select([table_a.c.b])),
            select([table_a.c.a]).lateral(),
            select([table_a.c.a]).lateral(name="bar"),
            table_a.tablesample(func.bernoulli(1)),
            table_a.tablesample(func.bernoulli(1), seed=func.random()),
            table_a.tablesample(func.bernoulli(1), seed=func.other_random()),
            table_a.tablesample(func.hoho(1)),
            table_a.tablesample(func.bernoulli(1), name="bar"),
            table_a.tablesample(
                func.bernoulli(1), name="bar", seed=func.random()
            ),
        ),
        lambda: (
            select([table_a.c.a]),
            select([table_a.c.a]).prefix_with("foo"),
            select([table_a.c.a]).prefix_with("foo", dialect="mysql"),
            select([table_a.c.a]).prefix_with("foo", dialect="postgresql"),
            select([table_a.c.a]).prefix_with("bar"),
            select([table_a.c.a]).suffix_with("bar"),
        ),
        lambda: (
            select([table_a_2.c.a]),
            select([table_a_2_fs.c.a]),
            select([table_a_2_bs.c.a]),
        ),
        lambda: (
            select([table_a.c.a]),
            select([table_a.c.a]).with_hint(None, "some hint"),
            select([table_a.c.a]).with_hint(None, "some other hint"),
            select([table_a.c.a]).with_hint(table_a, "some hint"),
            select([table_a.c.a])
            .with_hint(table_a, "some hint")
            .with_hint(None, "some other hint"),
            select([table_a.c.a]).with_hint(table_a, "some other hint"),
            select([table_a.c.a]).with_hint(
                table_a, "some hint", dialect_name="mysql"
            ),
            select([table_a.c.a]).with_hint(
                table_a, "some hint", dialect_name="postgresql"
            ),
        ),
        lambda: (
            table_a.join(table_b, table_a.c.a == table_b.c.a),
            table_a.join(
                table_b, and_(table_a.c.a == table_b.c.a, table_a.c.b == 1)
            ),
            table_a.outerjoin(table_b, table_a.c.a == table_b.c.a),
        ),
        lambda: (
            table_a.alias("a"),
            table_a.alias("b"),
            table_a.alias(),
            table_b.alias("a"),
            select([table_a.c.a]).alias("a"),
        ),
        lambda: (
            FromGrouping(table_a.alias("a")),
            FromGrouping(table_a.alias("b")),
        ),
        lambda: (
            SelectStatementGrouping(select([table_a])),
            SelectStatementGrouping(select([table_b])),
        ),
        lambda: (
            select([table_a.c.a]).scalar_subquery(),
            select([table_a.c.a]).where(table_a.c.b == 5).scalar_subquery(),
        ),
        lambda: (
            exists().where(table_a.c.a == 5),
            exists().where(table_a.c.b == 5),
        ),
        lambda: (
            union(select([table_a.c.a]), select([table_a.c.b])),
            union(select([table_a.c.a]), select([table_a.c.b])).order_by("a"),
            union_all(select([table_a.c.a]), select([table_a.c.b])),
            union(select([table_a.c.a])),
            union(
                select([table_a.c.a]),
                select([table_a.c.b]).where(table_a.c.b > 5),
            ),
        ),
        lambda: (
            table("a", column("x"), column("y")),
            table("a", column("y"), column("x")),
            table("b", column("x"), column("y")),
            table("a", column("x"), column("y"), column("z")),
            table("a", column("x"), column("y", Integer)),
            table("a", column("q"), column("y", Integer)),
        ),
        lambda: (table_a, table_b),
    ]

    # Constructs whose literal values are regenerated randomly on every call
    # (random_choices / random.randint).
    # NOTE(review): presumably consumed by tests that compare structure while
    # ignoring the bound values -- confirm against the tests using this list.
    dont_compare_values_fixtures = [
        lambda: (
            # note the in_(...) all have different column names because
            # otherwise all IN expressions would compare as equivalent
            column("x").in_(random_choices(range(10), k=3)),
            column("y").in_(
                bindparam(
                    "q",
                    random_choices(range(10), k=random.randint(0, 7)),
                    expanding=True,
                )
            ),
            column("z").in_(random_choices(range(10), k=random.randint(0, 7))),
            column("x") == random.randint(1, 10),
        )
    ]

    def _complex_fixtures():
        """Return a list of four multi-table SELECT statements.

        ``one`` and ``one_diff`` are identical except that the two anonymous
        aliases wrap ``table_a`` / ``table_b_like_a`` in swapped order;
        ``two`` joins ``table_b`` to a subquery of ``one``; ``three`` joins
        two aliases of ``table_a`` and filters on an EXISTS clause.
        """

        def one():
            a1 = table_a.alias()
            a2 = table_b_like_a.alias()

            stmt = (
                select([table_a.c.a, a1.c.b, a2.c.b])
                .where(table_a.c.b == a1.c.b)
                .where(a1.c.b == a2.c.b)
                .where(a1.c.a == 5)
            )

            return stmt

        def one_diff():
            a1 = table_b_like_a.alias()
            a2 = table_a.alias()

            stmt = (
                select([table_a.c.a, a1.c.b, a2.c.b])
                .where(table_a.c.b == a1.c.b)
                .where(a1.c.b == a2.c.b)
                .where(a1.c.a == 5)
            )

            return stmt

        def two():
            inner = one().subquery()

            stmt = select([table_b.c.a, inner.c.a, inner.c.b]).select_from(
                table_b.join(inner, table_b.c.b == inner.c.b)
            )

            return stmt

        def three():
            a1 = table_a.alias()
            a2 = table_a.alias()
            ex = exists().where(table_b.c.b == a1.c.a)

            stmt = (
                select([a1.c.a, a2.c.a])
                .select_from(a1.join(a2, a1.c.b == a2.c.b))
                .where(ex)
            )
            return stmt

        return [one(), one_diff(), two(), three()]

    # the function itself (not its result) is appended, so it sits in the
    # list as a callable alongside the lambdas
    fixtures.append(_complex_fixtures)
class CompareAndCopyTest(fixtures.TestBase):
    """Exercise ``compare()``, ``_cache_key()`` and ``_copy_internals()``.

    Each test walks the ``fixtures`` catalog below.  Every fixture is a
    zero-argument callable so that a test can build two independent but
    structurally identical sets of elements and check them against each
    other.
    """

    # lambdas which return a tuple of ColumnElement objects.
    # must return at least two objects that should compare differently.
    # to test more varieties of "difference" additional objects can be added.
    fixtures = [
        lambda: (
            column("q"),
            column("x"),
            column("q", Integer),
            column("q", String),
        ),
        lambda: (~column("q", Boolean), ~column("p", Boolean)),
        lambda: (
            table_a.c.a.label("foo"),
            table_a.c.a.label("bar"),
            table_a.c.b.label("foo"),
        ),
        lambda: (
            _label_reference(table_a.c.a.desc()),
            _label_reference(table_a.c.a.asc()),
        ),
        lambda: (
            _textual_label_reference("a"),
            _textual_label_reference("b"),
        ),
        lambda: (
            text("select a, b from table").columns(a=Integer, b=String),
            text("select a, b, c from table").columns(
                a=Integer, b=String, c=Integer
            ),
        ),
        lambda: (
            column("q") == column("x"),
            column("q") == column("y"),
            column("z") == column("x"),
        ),
        lambda: (
            cast(column("q"), Integer),
            cast(column("q"), Float),
            cast(column("p"), Integer),
        ),
        lambda: (
            bindparam("x"),
            bindparam("y"),
            bindparam("x", type_=Integer),
            bindparam("x", type_=String),
            bindparam(None),
        ),
        lambda: (_OffsetLimitParam("x"), _OffsetLimitParam("y")),
        lambda: (func.foo(), func.foo(5), func.bar()),
        lambda: (func.current_date(), func.current_time()),
        lambda: (
            func.next_value(Sequence("q")),
            func.next_value(Sequence("p")),
        ),
        lambda: (True_(), False_()),
        lambda: (Null(),),
        lambda: (ReturnTypeFromArgs("foo"), ReturnTypeFromArgs(5)),
        lambda: (FunctionElement(5), FunctionElement(5, 6)),
        lambda: (func.count(), func.not_count()),
        lambda: (func.char_length("abc"), func.char_length("def")),
        lambda: (GenericFunction("a", "b"), GenericFunction("a")),
        lambda: (CollationClause("foobar"), CollationClause("batbar")),
        lambda: (
            type_coerce(column("q", Integer), String),
            type_coerce(column("q", Integer), Float),
            type_coerce(column("z", Integer), Float),
        ),
        lambda: (table_a.c.a, table_b.c.a),
        lambda: (tuple_([1, 2]), tuple_([3, 4])),
        lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
        lambda: (
            func.percentile_cont(0.5).within_group(table_a.c.a),
            func.percentile_cont(0.5).within_group(table_a.c.b),
            func.percentile_cont(0.5).within_group(
                table_a.c.a, table_a.c.b
            ),
            func.percentile_cont(0.5).within_group(
                table_a.c.a, table_a.c.b, column("q")
            ),
        ),
        lambda: (
            func.is_equal("a", "b").as_comparison(1, 2),
            func.is_equal("a", "c").as_comparison(1, 2),
            func.is_equal("a", "b").as_comparison(2, 1),
            func.is_equal("a", "b", "c").as_comparison(1, 2),
            func.foobar("a", "b").as_comparison(1, 2),
        ),
        lambda: (
            func.row_number().over(order_by=table_a.c.a),
            func.row_number().over(order_by=table_a.c.a, range_=(0, 10)),
            func.row_number().over(order_by=table_a.c.a, range_=(None, 10)),
            func.row_number().over(order_by=table_a.c.a, rows=(None, 20)),
            func.row_number().over(order_by=table_a.c.b),
            func.row_number().over(
                order_by=table_a.c.a, partition_by=table_a.c.b
            ),
        ),
        lambda: (
            func.count(1).filter(table_a.c.a == 5),
            func.count(1).filter(table_a.c.a == 10),
            func.foob(1).filter(table_a.c.a == 10),
        ),
        lambda: (
            and_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            and_(table_a.c.a == 5, table_a.c.a == table_b.c.a),
            or_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            ClauseList(table_a.c.a == 5, table_a.c.b == table_b.c.a),
            ClauseList(table_a.c.a == 5, table_a.c.b == table_a.c.a),
        ),
        lambda: (
            case(whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)]),
            case(whens=[(table_a.c.a == 18, 10), (table_a.c.a == 10, 20)]),
            case(whens=[(table_a.c.a == 5, 10), (table_a.c.b == 10, 20)]),
            case(
                whens=[
                    (table_a.c.a == 5, 10),
                    (table_a.c.b == 10, 20),
                    (table_a.c.a == 9, 12),
                ]
            ),
            case(
                whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)],
                else_=30,
            ),
            case({"wendy": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
            case({"wendy": "W", "jack": "J"}, value=table_a.c.b, else_="E"),
            case(
                {"wendy_w": "W", "jack": "J"}, value=table_a.c.a, else_="E"
            ),
        ),
        lambda: (
            extract("foo", table_a.c.a),
            extract("foo", table_a.c.b),
            extract("bar", table_a.c.a),
        ),
        lambda: (
            Slice(1, 2, 5),
            Slice(1, 5, 5),
            Slice(1, 5, 10),
            Slice(2, 10, 15),
        ),
        lambda: (
            select([table_a.c.a]),
            select([table_a.c.a, table_a.c.b]),
            select([table_a.c.b, table_a.c.a]),
            select([table_a.c.a]).where(table_a.c.b == 5),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .where(table_a.c.a == 10),
            select([table_a.c.a]).where(table_a.c.b == 5).with_for_update(),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .with_for_update(nowait=True),
            select([table_a.c.a]).where(table_a.c.b == 5).correlate(table_b),
            select([table_a.c.a])
            .where(table_a.c.b == 5)
            .correlate_except(table_b),
        ),
        lambda: (
            table_a.join(table_b, table_a.c.a == table_b.c.a),
            table_a.join(
                table_b, and_(table_a.c.a == table_b.c.a, table_a.c.b == 1)
            ),
            table_a.outerjoin(table_b, table_a.c.a == table_b.c.a),
        ),
        lambda: (
            table_a.alias("a"),
            table_a.alias("b"),
            table_a.alias(),
            table_b.alias("a"),
            select([table_a.c.a]).alias("a"),
        ),
        lambda: (
            FromGrouping(table_a.alias("a")),
            FromGrouping(table_a.alias("b")),
        ),
        lambda: (
            select([table_a.c.a]).as_scalar(),
            select([table_a.c.a]).where(table_a.c.b == 5).as_scalar(),
        ),
        lambda: (
            exists().where(table_a.c.a == 5),
            exists().where(table_a.c.b == 5),
        ),
        lambda: (
            union(select([table_a.c.a]), select([table_a.c.b])),
            union(select([table_a.c.a]), select([table_a.c.b])).order_by(
                "a"
            ),
            union_all(select([table_a.c.a]), select([table_a.c.b])),
            union(select([table_a.c.a])),
            union(
                select([table_a.c.a]),
                select([table_a.c.b]).where(table_a.c.b > 5),
            ),
        ),
        lambda: (
            table("a", column("x"), column("y")),
            table("a", column("y"), column("x")),
            table("b", column("x"), column("y")),
            table("a", column("x"), column("y"), column("z")),
            table("a", column("x"), column("y", Integer)),
            table("a", column("q"), column("y", Integer)),
        ),
        lambda: (
            Table("a", MetaData(), Column("q", Integer), Column("b", String)),
            Table("b", MetaData(), Column("q", Integer), Column("b", String)),
        ),
    ]

    @classmethod
    def setup_class(cls):
        """Import every public dialect module listed in ``dialects.__all__``.

        Names starting with an underscore are skipped.  The imports are
        performed only for their side effects; nothing is returned or stored.
        """
        # TODO: we need to get dialects here somehow, perhaps in test_suite?
        for dialect_name in dialects.__all__:
            if not dialect_name.startswith("_"):
                importlib.import_module(
                    "sqlalchemy.dialects.%s" % dialect_name
                )

    def test_all_present(self):
        """Assert that the fixtures cover every class that needs coverage.

        A class needs coverage when it is found by
        ``class_hierarchy(ClauseElement)``, is a ``ColumnElement`` or
        ``Selectable`` subclass, defines its own ``__init__``, is not an
        ``Annotated`` subclass, and lives in a module whose name contains
        none of "orm", "crud" or "dialects".  ``ColumnElement`` and
        ``UnaryExpression`` themselves are exempt.
        """
        need = {
            cls
            for cls in class_hierarchy(ClauseElement)
            if issubclass(cls, (ColumnElement, Selectable))
            and "__init__" in cls.__dict__
            and not issubclass(cls, Annotated)
            and "orm" not in cls.__module__
            and "crud" not in cls.__module__
            and "dialects" not in cls.__module__  # TODO: dialects?
        }.difference({ColumnElement, UnaryExpression})

        for fixture in self.fixtures:
            for elem in fixture():
                # a fixture element accounts for its own class and for
                # every class in its MRO
                need.difference_update(type(elem).__mro__)

        is_false(bool(need), "%d Remaining classes: %r" % (len(need), need))

    def test_compare(self):
        """Check ``compare()`` across two independent builds of each fixture.

        Elements at the same index must compare as equal; elements at
        different indexes must not.
        """
        for fixture in self.fixtures:
            case_a = fixture()
            case_b = fixture()

            for a, b in itertools.combinations_with_replacement(
                range(len(case_a)), 2
            ):
                if a == b:
                    is_true(
                        case_a[a].compare(
                            case_b[b], arbitrary_expression=True
                        ),
                        "%r != %r" % (case_a[a], case_b[b]),
                    )
                else:
                    is_false(
                        case_a[a].compare(
                            case_b[b], arbitrary_expression=True
                        ),
                        "%r == %r" % (case_a[a], case_b[b]),
                    )

    def test_cache_key(self):
        """Check ``_cache_key()`` when no "bindparams" collection is passed.

        Elements containing a bound parameter flagged with
        ``_value_required_for_cache`` must raise ``NotImplementedError``.
        When neither side has such a parameter, keys must be equal for the
        same fixture index and unequal otherwise.
        """

        def assert_params_append(assert_params):
            # Build a visitor that collects the params whose value is
            # required for caching and asserts the others carry no value.
            def append(param):
                if param._value_required_for_cache:
                    assert_params.append(param)
                else:
                    is_(param.value, None)

            return append

        for fixture in self.fixtures:
            case_a = fixture()
            case_b = fixture()

            for a, b in itertools.combinations_with_replacement(
                range(len(case_a)), 2
            ):
                assert_a_params = []
                assert_b_params = []

                visitors.traverse_depthfirst(
                    case_a[a],
                    {},
                    {"bindparam": assert_params_append(assert_a_params)},
                )
                visitors.traverse_depthfirst(
                    case_b[b],
                    {},
                    {"bindparam": assert_params_append(assert_b_params)},
                )

                if assert_a_params:
                    assert_raises_message(
                        NotImplementedError,
                        "bindparams collection argument required ",
                        case_a[a]._cache_key,
                    )
                if assert_b_params:
                    assert_raises_message(
                        NotImplementedError,
                        "bindparams collection argument required ",
                        case_b[b]._cache_key,
                    )

                if not assert_a_params and not assert_b_params:
                    if a == b:
                        eq_(case_a[a]._cache_key(), case_b[b]._cache_key())
                    else:
                        ne_(case_a[a]._cache_key(), case_b[b]._cache_key())

    def test_cache_key_gather_bindparams(self):
        """Check ``_cache_key()`` when a "bindparams" collection is passed.

        Besides comparing the keys, this asserts that the gathered
        parameters match, in order and including duplicates, what a
        depth-first traversal of the same element yields.
        """
        for fixture in self.fixtures:
            case_a = fixture()
            case_b = fixture()

            # in the "bindparams" case, the cache keys for bound parameters
            # with only different values will be the same, but the params
            # themselves are gathered into a collection.
            for a, b in itertools.combinations_with_replacement(
                range(len(case_a)), 2
            ):
                a_params = {"bindparams": []}
                b_params = {"bindparams": []}

                # both branches below need the keys, so compute them once
                a_key = case_a[a]._cache_key(**a_params)
                b_key = case_b[b]._cache_key(**b_params)

                if a == b:
                    eq_(a_key, b_key)

                    if a_params["bindparams"]:
                        for a_param, b_param in zip(
                            a_params["bindparams"], b_params["bindparams"]
                        ):
                            assert a_param.compare(b_param)
                elif a_key == b_key:
                    # identical keys are only acceptable for different
                    # elements when at least one gathered parameter differs
                    for a_param, b_param in zip(
                        a_params["bindparams"], b_params["bindparams"]
                    ):
                        if not a_param.compare(b_param):
                            break
                    else:
                        assert False, "Bound parameters are all the same"
                else:
                    ne_(a_key, b_key)

                assert_a_params = []
                assert_b_params = []
                visitors.traverse_depthfirst(
                    case_a[a], {}, {"bindparam": assert_a_params.append}
                )
                visitors.traverse_depthfirst(
                    case_b[b], {}, {"bindparam": assert_b_params.append}
                )

                # note we're asserting the order of the params as well as
                # if there are dupes or not.  ordering has to be
                # deterministic and matches what a traversal would provide.
                eq_(a_params["bindparams"], assert_a_params)
                eq_(b_params["bindparams"], assert_b_params)

    def test_compare_col_identity(self):
        """Check ``compare()`` with ``use_proxies`` and ``equivalents``."""
        stmt1 = (
            select([table_a.c.a, table_b.c.b])
            .where(table_a.c.a == table_b.c.b)
            .alias()
        )
        stmt1_c = (
            select([table_a.c.a, table_b.c.b])
            .where(table_a.c.a == table_b.c.b)
            .alias()
        )

        stmt2 = union(select([table_a]), select([table_b]))

        equivalents = {table_a.c.a: [table_b.c.a]}

        is_false(
            stmt1.compare(stmt2, use_proxies=True, equivalents=equivalents)
        )

        is_true(
            stmt1.compare(stmt1_c, use_proxies=True, equivalents=equivalents)
        )
        is_true(
            (table_a.c.a == table_b.c.b).compare(
                stmt1.c.a == stmt1.c.b,
                use_proxies=True,
                equivalents=equivalents,
            )
        )

    def test_copy_internals(self):
        """Check that ``_copy_internals()`` detaches a clone from its source.

        The first element of each fixture is cloned, and the sub-elements
        found on the clone are altered so that they can no longer compare
        as equal.  The clone must then stop matching, while the element it
        was cloned from must still match.
        """
        for fixture in self.fixtures:
            case_a = fixture()
            case_b = fixture()

            assert case_a[0].compare(case_b[0])

            clone = case_a[0]._clone()
            clone._copy_internals()

            assert clone.compare(case_b[0])

            stack = [clone]
            seen = {clone}
            found_elements = False
            while stack:
                obj = stack.pop(0)

                # NOTE(review): this reads clone.__dict__ on every pass,
                # not obj.__dict__, so only elements referenced directly by
                # the clone are ever collected.  Confirm whether a deeper
                # walk was intended before changing it.
                items = [
                    subelem
                    for key, elem in clone.__dict__.items()
                    if key != "_is_clone_of" and elem is not None
                    for subelem in util.to_list(elem)
                    if (
                        isinstance(subelem, (ColumnElement, ClauseList))
                        and subelem not in seen
                        and not isinstance(subelem, Immutable)
                        and subelem is not case_a[0]
                    )
                ]
                stack.extend(items)
                seen.update(items)

                if obj is not clone:
                    found_elements = True
                    # ensure the element will not compare as true
                    obj.compare = lambda other, **kw: False
                    obj.__visit_name__ = "dont_match"

            if found_elements:
                assert not clone.compare(case_b[0])
            assert case_a[0].compare(case_b[0])
def __init__(self, data):
    """Keep *data* on the instance, then run ``FunctionElement.__init__``.

    :param data: value stored as ``self.data`` and passed on as the single
     positional argument of ``FunctionElement.__init__``.
    """
    self.data = data
    # the base initializer is invoked explicitly and receives the value
    # read back from the attribute assigned above
    FunctionElement.__init__(self, self.data)