Ejemplo n.º 1
0
 def setup(self):
     """Build two expressions whose schema contains every registered DF type.

     ``self.expr`` has no source data; ``self.sourced_expr`` is backed by a
     ``MockTable`` bound to the test client's REST object.
     """
     from odps.df.expr.tests.core import MockTable
     all_types = types._data_types
     schema = Schema.from_lists(all_types.keys(), all_types.values())
     self.expr = CollectionExpr(_source_data=None, _schema=schema)
     mock_source = MockTable(client=self.odps.rest)
     self.sourced_expr = CollectionExpr(_source_data=mock_source,
                                        _schema=schema)
    def testMakeKV(self):
        """``to_kv`` packs the non-null kv columns of each row into one string.

        A temporary ODPS table is created, filled with sparse double rows and
        dropped in ``finally``.
        """
        from odps import types as odps_types

        kv_cols = ['k1', 'k2', 'k3', 'k5', 'k7', 'k9']
        rows = [
            ['name1', 1.0, 3.0, None, 10.0, None, None],
            ['name1', None, 3.0, 5.1, None, None, None],
            ['name1', 7.1, None, None, None, 8.2, None],
            ['name2', None, 1.2, 1.5, None, None, None],
            ['name2', None, 1.0, None, None, None, 1.1],
        ]
        col_types = [odps_types.string] + [odps_types.double] * 6
        schema = Schema.from_lists(['name'] + kv_cols, col_types)

        table_name = tn('pyodps_test_engine_make_kv')
        self.odps.delete_table(table_name, if_exists=True)
        table = self.odps.create_table(name=table_name, schema=schema)
        df_schema = odps_schema_to_df_schema(schema)
        source = CollectionExpr(_source_data=table, _schema=df_schema)
        try:
            self.odps.write_table(table, 0, rows)
            kv_expr = source.to_kv(columns=kv_cols, kv_delim='=')

            result = self._get_result(self.engine.execute(kv_expr))

            # Null cells are omitted; whole-number doubles print without '.0'.
            self.assertListEqual(result, [
                ['name1', 'k1=1,k2=3,k5=10'],
                ['name1', 'k2=3,k3=5.1'],
                ['name1', 'k1=7.1,k7=8.2'],
                ['name2', 'k2=1.2,k3=1.5'],
                ['name2', 'k2=1,k9=1.1'],
            ])
        finally:
            table.drop()
Ejemplo n.º 3
0
    def setup(self):
        """Create a SQLAlchemy-backed test table and an expression over it.

        Connects to the local ``pyodps`` PostgreSQL database, creates the
        ``pyodps_test_data`` table from the DataFrame schema and stores the
        engine, connection, metadata, table and expression on ``self``.
        """
        datatypes = lambda *types: [validate_data_type(t) for t in types]
        schema = Schema.from_lists(['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
                                   datatypes('string', 'int64', 'float64', 'boolean', 'decimal', 'datetime'))
        self.df_schema = schema
        self.schema = df_schema_to_odps_schema(schema)
        self.df = None
        self.expr = None

        self.engine = SQLAlchemyEngine()

        import sqlalchemy
        from sqlalchemy import create_engine

        # 'postgresql' is SQLAlchemy's canonical dialect name. The 'postgres'
        # alias was removed in SQLAlchemy 1.4 and raises NoSuchModuleError.
        self.sql_engine = engine = create_engine('postgresql://localhost/pyodps')
        # Alternative backends that can be swapped in for local runs:
        # self.sql_engine = engine = create_engine('mysql://localhost/pyodps')
        # self.sql_engine = engine = create_engine('sqlite://')
        self.conn = engine.connect()

        # NOTE(review): MetaData(bind=...) and create_all() without a bind
        # argument are removed in SQLAlchemy 2.0; this targets the 1.x API.
        self.metadata = metadata = sqlalchemy.MetaData(bind=engine)
        columns = df_schema_to_sqlalchemy_columns(self.df_schema, engine=self.sql_engine)
        t = sqlalchemy.Table('pyodps_test_data', metadata, *columns)

        metadata.create_all()

        self.table = t
        self.expr = CollectionExpr(_source_data=self.table, _schema=self.df_schema)

        class FakeBar(object):
            # Progress-bar stub: update() ignores every call.
            def update(self, *args, **kwargs):
                pass
        self.faked_bar = FakeBar()
    def testSVGFormatter(self):
        """The dot graph shows a 'Projection' only when the groupby is cached.

        Without ``cache()`` the compiled DAG has a single groupby node; with it
        the DAG has two nodes and the projection appears in the dot output.
        """
        table = MockTable(name='pyodps_test_svg', schema=self.schema, _client=self.odps.rest)
        source = CollectionExpr(_source_data=table, _schema=self.schema)

        def compile_projection(grouped):
            # Project over *grouped* and compile with a fresh MixedEngine.
            projected = grouped['name', grouped.id + 3]
            dag = MixedEngine(self.odps).compile(projected)
            return dag, dag.nodes()

        dag, nodes = compile_projection(
            source.groupby('name').agg(id=source['id'].sum()))
        self.assertEqual(len(nodes), 1)
        self.assertIsInstance(nodes[0].expr, GroupByCollectionExpr)
        self.assertNotIn('Projection', ExprExecutionGraphFormatter(dag)._to_dot())

        dag, nodes = compile_projection(
            source.groupby('name').agg(id=source['id'].sum()).cache())
        self.assertEqual(len(nodes), 2)
        self.assertIn('Projection', ExprExecutionGraphFormatter(dag)._to_dot())
    def setup(self):
        """Create a uniquely named ODPS table and a Seahawks engine for the tests.

        Only the first five of the six listed columns are used, so 'scale'
        (decimal) is left out of the schema.
        """
        def to_types(*names):
            return [validate_data_type(n) for n in names]

        col_names = ['name', 'id', 'fid', 'isMale', 'birth', 'scale']
        col_types = to_types('string', 'int64', 'float64', 'boolean', 'datetime',
                             'decimal')
        schema = Schema.from_lists(col_names[:5], col_types[:5])
        self.schema = df_schema_to_odps_schema(schema)

        suffix = str(uuid.uuid4()).replace('-', '_')
        table_name = tn('pyodps_test_%s' % suffix)
        self.odps.delete_table(table_name, if_exists=True)
        self.table = self.odps.create_table(name=table_name, schema=self.schema)
        self.expr = CollectionExpr(_source_data=self.table, _schema=schema)

        self.engine = SeahawksEngine(self.odps)

        class FakeBar(object):
            """Progress-bar stub; every method accepts anything and does nothing."""
            def update(self, *args, **kwargs):
                pass

            inc = status = update

        self.faked_bar = FakeBar()
Ejemplo n.º 6
0
    def setup(self):
        """Expose a name/id/fid mock-table expression and an ODPS engine."""
        column_types = [validate_data_type(t)
                        for t in ('string', 'int64', 'float64')]
        schema = Schema.from_lists(['name', 'id', 'fid'], column_types)

        mock = MockTable(name='pyodps_test_expr_table', schema=schema)
        self.expr = CollectionExpr(_source_data=mock, _schema=schema)

        self.engine = ODPSEngine(self.odps)
    def testSVGFormatter(self):
        """Formatting the DAG of a cached groupby projection yields a truthy dot string."""
        source_table = MockTable(name='pyodps_test_svg', schema=self.schema,
                                 _client=self.odps.rest)
        expr = CollectionExpr(_source_data=source_table, _schema=self.schema)

        cached = expr.groupby('name').agg(id=expr['id'].sum()).cache()
        projected = cached['name', cached.id + 3]

        dag, compiled, _ = MixedEngine(self.odps)._compile(projected)
        dot = ExprExecutionGraphFormatter(compiled, dag)._to_dot()
        self.assertTrue(dot)
    def setup(self):
        """Create two independent two-column mock expressions, expr and expr2."""
        def make_types(*names):
            return [validate_data_type(n) for n in names]

        first_schema = Schema.from_lists(['name', 'id'], make_types('string', 'int64'))
        first_table = MockTable(name='pyodps_test_expr_table', schema=first_schema)
        self.expr = CollectionExpr(_source_data=first_table, _schema=first_schema)

        second_schema = Schema.from_lists(['name2', 'id2'], make_types('string', 'int64'))
        second_table = MockTable(name='pyodps_test_expr_table2', schema=second_schema)
        self.expr2 = CollectionExpr(_source_data=second_table, _schema=second_schema)
    def testScaleValue(self):
        """Check ``min_max_scale`` and ``std_scale`` on the ``fid`` column.

        A temporary ODPS table is created and filled with five rows. Results
        are compared order-insensitively with ``assertAlmostEqual``. The table
        is dropped in ``finally`` so a failing assertion does not leak it.
        """
        data = [
            ['name1', 4, 5.3],
            ['name2', 2, 3.5],
            ['name1', 4, 4.2],
            ['name1', 3, 2.2],
            ['name1', 3, 4.1],
        ]
        schema = Schema.from_lists(['name', 'id', 'fid'],
                                   [types.string, types.bigint, types.double])
        table_name = tn('pyodps_test_engine_scale_table')
        self.odps.delete_table(table_name, if_exists=True)
        table = self.odps.create_table(name=table_name, schema=schema)

        def assert_rows_almost_equal(result, expected):
            # Sort both sides; check the row count first because zip() would
            # otherwise stop silently at the shorter list.
            result = sorted(result)
            expected = sorted(expected)
            self.assertEqual(len(result), len(expected))
            for first, second in zip(result, expected):
                self.assertEqual(len(first), len(second))
                for it1, it2 in zip(first, second):
                    self.assertAlmostEqual(it1, it2)

        try:
            self.odps.write_table(table_name, 0, data)
            expr_input = CollectionExpr(_source_data=table,
                                        _schema=odps_schema_to_df_schema(schema))

            expr = expr_input.min_max_scale(columns=['fid'])
            res = self.engine.execute(expr)
            assert_rows_almost_equal(self._get_result(res), [
                ['name1', 4, 1.0], ['name2', 2, 0.41935483870967744],
                ['name1', 4, 0.6451612903225807], ['name1', 3, 0.0],
                ['name1', 3, 0.6129032258064515]])

            expr = expr_input.std_scale(columns=['fid'])
            res = self.engine.execute(expr)
            assert_rows_almost_equal(self._get_result(res), [
                ['name1', 4, 1.4213602653434203],
                ['name2', 2, -0.3553400663358544],
                ['name1', 4, 0.3355989515394193],
                ['name1', 3, -1.6385125281042194],
                ['name1', 3, 0.23689337755723686]])
        finally:
            table.drop()
Ejemplo n.º 10
0
    def setup(self):
        """Create the two ODPS tables and expressions used by the engine-selecter tests.

        Sets ``self.schema``, ``self.table``/``self.expr`` (a uniquely named
        six-column table), ``self.expr2`` (over ``pyodps_test_selecter_table2``),
        ``self.faked_bar`` and ``self.selecter``.
        """
        datatypes = lambda *types: [validate_data_type(t) for t in types]
        schema = Schema.from_lists(
            ['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
            datatypes('string', 'int64', 'float64', 'boolean', 'decimal',
                      'datetime'))
        self.schema = df_schema_to_odps_schema(schema)
        # A uuid suffix gives each run its own primary table.
        table_name = tn('pyodps_test_selecter_table_%s' %
                        str(uuid.uuid4()).replace('-', '_'))
        self.odps.delete_table(table_name, if_exists=True)
        self.table = self.odps.create_table(name=table_name,
                                            schema=self.schema)
        self.expr = CollectionExpr(_source_data=self.table, _schema=schema)

        # Progress-bar stub: update/inc/status accept anything and do nothing.
        class FakeBar(object):
            def update(self, *args, **kwargs):
                pass

            def inc(self, *args, **kwargs):
                pass

            def status(self, *args, **kwargs):
                pass

        self.faked_bar = FakeBar()

        # One row per line; six values per row, matching the six columns above.
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])

        # NOTE(review): unlike the first table this name has no uuid suffix,
        # so concurrent runs may collide on it -- confirm this is intended.
        table_name = tn('pyodps_test_selecter_table2')
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)
        self.expr2 = CollectionExpr(_source_data=table2,
                                    _schema=odps_schema_to_df_schema(schema2))

        # _gen_data is defined elsewhere; presumably it writes `data` into
        # self.table -- verify against the test base class.
        self._gen_data(data=data)

        data2 = [['name1', 4, -1], ['name2', 1, -2]]

        self.odps.write_table(table2, 0, data2)

        self.selecter = EngineSelecter()
Ejemplo n.º 11
0
    def testBizarreField(self):
        """A column named '012' (not a Python identifier) can be read in a row UDF.

        The compiled UDTF is run locally and should double the '012' value.
        """
        def my_func(row):
            return getattr(row, '012') * 2.0

        to_types = lambda *names: [validate_data_type(n) for n in names]
        schema = Schema.from_lists(['name', 'id', 'fid', '012'],
                                   to_types('string', 'int64', 'float64', 'float64'))
        source_table = MockTable(name='pyodps_test_expr_table', schema=schema)
        expr = CollectionExpr(_source_data=source_table, _schema=schema)

        applied = expr.apply(my_func, axis=1, names=['out_col'], types=['float64'])
        self.engine.compile(applied)
        compiled_udfs = list(self.engine._ctx._func_to_udfs.values())
        udtf = get_function(compiled_udfs[0], UDF_CLASS_NAME)

        input_rows = [('name1', 1, None, 10), ('name2', 2, None, 20)]
        self.assertEqual([20, 40], runners.simple_run(udtf, input_rows))
    def setup(self):
        """Create ``self.expr`` over a mock table with six mixed-type columns."""
        names = ['name', 'id', 'fid', 'isMale', 'scale', 'birth']
        type_names = ['string', 'int64', 'float64', 'boolean', 'decimal', 'datetime']
        schema = Schema.from_lists(names, [validate_data_type(t) for t in type_names])
        table = MockTable(name='pyodps_test_expr_table', schema=schema)

        self.expr = CollectionExpr(_source_data=table, _schema=schema)
Ejemplo n.º 13
0
    def setup(self):
        """Recreate ``pyodps_test_engine_table`` and wrap it in an expression and engine."""
        to_types = lambda *names: [validate_data_type(n) for n in names]
        df_schema = Schema.from_lists(
            ['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
            to_types('string', 'int64', 'float64', 'boolean', 'decimal', 'datetime'))
        self.schema = df_schema_to_odps_schema(df_schema)

        table_name = 'pyodps_test_engine_table'
        # Remove any leftover table from an earlier run before creating it again.
        self.odps.delete_table(table_name, if_exists=True)
        self.table = self.odps.create_table(name=table_name, schema=self.schema)
        self.expr = CollectionExpr(_source_data=self.table, _schema=df_schema)

        self.engine = ODPSEngine(self.odps)

        class FakeBar(object):
            """Progress-bar stub whose ``update`` ignores every call."""
            def update(self, *args, **kwargs):
                pass

        self.faked_bar = FakeBar()
Ejemplo n.º 14
0
    def setup(self):
        """Build mock-table expressions ``expr``..``expr4`` for compiler tests.

        ``expr``/``expr1``/``expr2`` share a six-column schema with a ``ds``
        partition; ``expr3`` uses a three-column schema with two partitions;
        ``expr4`` exposes a schema containing dict and list columns.
        """
        datatypes = lambda *types: [validate_data_type(t) for t in types]
        schema = Schema.from_lists(
            ['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
            datatypes('string', 'int64', 'float64', 'boolean', 'decimal',
                      'datetime'), ['ds'], datatypes('string'))
        table = MockTable(name='pyodps_test_expr_table', schema=schema)
        # Each expression gets its own Schema built from ``schema.columns``.
        self.expr = CollectionExpr(_source_data=table,
                                   _schema=Schema(columns=schema.columns))

        table1 = MockTable(name='pyodps_test_expr_table1', schema=schema)
        self.expr1 = CollectionExpr(_source_data=table1,
                                    _schema=Schema(columns=schema.columns))

        table2 = MockTable(name='pyodps_test_expr_table2', schema=schema)
        self.expr2 = CollectionExpr(_source_data=table2,
                                    _schema=Schema(columns=schema.columns))

        schema2 = Schema.from_lists(['name', 'id', 'fid'],
                                    datatypes('string', 'int64', 'float64'),
                                    ['part1', 'part2'],
                                    datatypes('string', 'int64'))
        # NOTE(review): reuses table2's name 'pyodps_test_expr_table2' --
        # confirm compiled SQL is not meant to distinguish the two.
        table3 = MockTable(name='pyodps_test_expr_table2', schema=schema2)
        self.expr3 = CollectionExpr(_source_data=table3,
                                    _schema=Schema(columns=schema2.columns))

        schema3 = Schema.from_lists(['id', 'name', 'relatives', 'hobbies'],
                                    datatypes('int64', 'string',
                                              'dict<string, string>',
                                              'list<string>'))
        # NOTE(review): table4 is built with `schema` (and table's name) while
        # expr4 uses `schema3`; presumably only the expression schema matters
        # here -- verify before relying on table4's schema.
        table4 = MockTable(name='pyodps_test_expr_table', schema=schema)
        self.expr4 = CollectionExpr(_source_data=table4, _schema=schema3)
    def setup(self):
        """Build mock expressions ``expr``/``expr1``/``expr2`` and ``expr3``.

        The first three share one partitioned schema; ``expr3`` uses a second
        schema with two partition columns. Each expression gets a fresh
        ``Schema`` built from the source schema's ``columns``.
        """
        datatypes = lambda *types: [validate_data_type(t) for t in types]
        schema = Schema.from_lists(
            ['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
            datatypes('string', 'int64', 'float64', 'boolean', 'decimal', 'datetime'),
            ['ds'], datatypes('string'))

        for attr, table_name in (('expr', 'pyodps_test_expr_table'),
                                 ('expr1', 'pyodps_test_expr_table1'),
                                 ('expr2', 'pyodps_test_expr_table2')):
            table = MockTable(name=table_name, schema=schema)
            setattr(self, attr, CollectionExpr(
                _source_data=table, _schema=Schema(columns=schema.columns)))

        schema2 = Schema.from_lists(
            ['name', 'id', 'fid'], datatypes('string', 'int64', 'float64'),
            ['part1', 'part2'], datatypes('string', 'int64'))
        # Shares its table name with expr2's mock table.
        table3 = MockTable(name='pyodps_test_expr_table2', schema=schema2)
        self.expr3 = CollectionExpr(_source_data=table3,
                                    _schema=Schema(columns=schema2.columns))
Ejemplo n.º 16
0
    def setup(self):
        """Build ``self.expr`` (name, id) and ``self.expr2`` (name2, id2) on mock tables."""
        def build(names, table_name):
            # Both expressions use a string column followed by an int64 column.
            schema = Schema.from_lists(
                names, [validate_data_type(t) for t in ('string', 'int64')])
            table = MockTable(name=table_name, schema=schema)
            return CollectionExpr(_source_data=table, _schema=schema)

        self.expr = build(['name', 'id'], 'pyodps_test_expr_table')
        self.expr2 = build(['name2', 'id2'], 'pyodps_test_expr_table2')
    def setup(self):
        """Attach two mock-table expressions, ``expr`` and ``expr2``, to the test."""
        datatypes = lambda *types: [validate_data_type(t) for t in types]
        specs = (
            ("expr", ["name", "id"], "pyodps_test_expr_table"),
            ("expr2", ["name2", "id2"], "pyodps_test_expr_table2"),
        )
        for attr, names, table_name in specs:
            schema = Schema.from_lists(names, datatypes("string", "int64"))
            table = MockTable(name=table_name, schema=schema)
            setattr(self, attr, CollectionExpr(_source_data=table, _schema=schema))
Ejemplo n.º 18
0
    def setup(self):
        """Prepare a mock three-column expression plus an ODPS SQL engine."""
        schema = Schema.from_lists(
            ['name', 'id', 'fid'],
            list(map(validate_data_type, ['string', 'int64', 'float64'])))
        self.expr = CollectionExpr(
            _source_data=MockTable(name='pyodps_test_expr_table', schema=schema),
            _schema=schema)
        self.engine = ODPSEngine(self.odps)
Ejemplo n.º 19
0
    def testCallableColumn(self):
        """A column named like the ``append_id`` method is a ``CallableColumn``.

        It projects like a normal column, stops being callable once projected
        away, and calling it still produces a collection with the new column.
        """
        from odps.df.expr.expressions import CallableColumn
        from odps.df.expr.collections import ProjectCollectionExpr

        schema = Schema.from_lists(['name', 'f1', 'append_id'],
                                   [types.string, types.float64, types.int64])
        expr = CollectionExpr(_source_data=None, _schema=schema)

        self.assertIsInstance(expr.append_id, CallableColumn)
        self.assertNotIsInstance(expr.f1, CallableColumn)

        with_column = expr[expr.name, expr.append_id]
        self.assertIsInstance(with_column, ProjectCollectionExpr)
        self.assertListEqual(with_column.schema.names, ['name', 'append_id'])

        without_column = expr[expr.name, expr.f1]
        self.assertNotIsInstance(without_column.append_id, CallableColumn)

        self.assertIn('id_col', expr.append_id(id_col='id_col').schema)
Ejemplo n.º 20
0
    def testSVGFormatter(self):
        """'Projection' appears in the dot graph only when the groupby is cached."""
        mock = MockTable(name="pyodps_test_svg", schema=self.schema, _client=self.odps.rest)
        expr = CollectionExpr(_source_data=mock, _schema=self.schema)

        def compile_projection(grouped):
            # Project over *grouped* and compile with a fresh MixedEngine.
            projected = grouped["name", grouped.id + 3]
            dag, compiled, _ = MixedEngine(self.odps)._compile(projected)
            return dag, compiled

        dag, compiled = compile_projection(expr.groupby("name").agg(id=expr["id"].sum()))
        self.assertIsInstance(compiled, GroupByCollectionExpr)
        self.assertNotIn("Projection", ExprExecutionGraphFormatter(compiled, dag)._to_dot())

        dag, compiled = compile_projection(
            expr.groupby("name").agg(id=expr["id"].sum()).cache())
        self.assertIn("Projection", ExprExecutionGraphFormatter(compiled, dag)._to_dot())
    def setup(self):
        """Create mock expressions ``expr``..``expr2`` (partitioned) and ``expr3``.

        Also disables unittest's diff truncation via ``maxDiff = None``.
        """
        datatypes = lambda *types: [validate_data_type(t) for t in types]
        schema = Schema.from_lists(
            ['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
            datatypes('string', 'int64', 'float64', 'boolean', 'decimal', 'datetime'),
            ['ds'], datatypes('string'))
        for attr, table_name in (('expr', 'pyodps_test_expr_table'),
                                 ('expr1', 'pyodps_test_expr_table1'),
                                 ('expr2', 'pyodps_test_expr_table2')):
            table = MockTable(name=table_name, schema=schema)
            setattr(self, attr, CollectionExpr(_source_data=table, _schema=schema))

        schema2 = Schema.from_lists(
            ['name', 'id', 'fid'], datatypes('string', 'int64', 'float64'),
            ['part1', 'part2'], datatypes('string', 'int64'))
        table3 = MockTable(name='pyodps_test_expr_table2', schema=schema2)
        self.expr3 = CollectionExpr(_source_data=table3, _schema=schema2)

        self.maxDiff = None
Ejemplo n.º 22
0
    def testUnion(self):
        """Union a distinct ``(name, id)`` projection with rows of a second table.

        ``pyodps_test_engine_table2`` is (re)created for this test. The ``try``
        begins right after creation so the table is dropped even when loading
        the test data fails, not only when an assertion fails.
        """
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]
        data2 = [
            ['name3', 5, -1],
            ['name4', 6, -2]
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])
        table_name = 'pyodps_test_engine_table2'
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)

        try:
            expr2 = CollectionExpr(_source_data=table2, _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)
            self.odps.write_table(table2, 0, [table2.new_record(values=d) for d in data2])

            # The right-hand side is projected as (id, name), with id2 renamed to id.
            expr = self.expr['name', 'id'].distinct().union(expr2[expr2.id2.rename('id'), 'name'])

            res = self.engine.execute(expr)
            result = self._get_result(res)

            expected = [
                ['name1', 4],
                ['name1', 3],
                ['name2', 2],
                ['name3', 5],
                ['name4', 6]
            ]

            result = sorted(result)
            expected = sorted(expected)

            self.assertEqual(len(result), len(expected))
            # Compare the to_str() form of every value.
            for e, r in zip(result, expected):
                self.assertEqual([to_str(t) for t in e],
                                 [to_str(t) for t in r])

        finally:
            table2.drop()
Ejemplo n.º 23
0
    def testUnion(self):
        """Union a distinct ``(name, id)`` projection with rows of a second table.

        The helper table is created and filled by ``_create_table_and_insert_data``.
        The ``try`` covers every later step, so the cached connections are
        always closed and the helper table dropped.
        """
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        data2 = [
            ['name3', 5, -1],
            ['name4', 6, -2]
        ]

        datatypes = lambda *types: [validate_data_type(t) for t in types]
        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    datatypes('string', 'int64', 'int64'))
        table_name = tn('pyodps_test_engine_table2')
        table2 = self._create_table_and_insert_data(table_name, schema2, data2)

        try:
            expr2 = CollectionExpr(_source_data=table2, _schema=schema2)

            self._gen_data(data=data)

            # The right-hand side is projected as (id, name), with id2 renamed to id.
            expr = self.expr['name', 'id'].distinct().union(expr2[expr2.id2.rename('id'), 'name'])

            res = self.engine.execute(expr)
            result = self._get_result(res)

            expected = [
                ['name1', 4],
                ['name1', 3],
                ['name2', 2],
                ['name3', 5],
                ['name4', 6]
            ]

            result = sorted(result)
            expected = sorted(expected)

            self.assertEqual(len(result), len(expected))
            for e, r in zip(result, expected):
                self.assertEqual([to_str(t) for t in e],
                                 [to_str(t) for t in r])

        finally:
            # Close the cached connections before dropping the table; a plain
            # loop replaces the side-effect-only list comprehension.
            for conn in _engine_to_connections.values():
                conn.close()
            table2.drop()
Ejemplo n.º 24
0
    def testJoin(self):
        """Join ``self.expr`` with a second table, first on defaults and then on keys.

        ``pyodps_test_engine_table2`` is (re)created for this test. The ``try``
        begins right after creation so the table is dropped even when loading
        the test data fails.
        """
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]
        data2 = [
            ['name1', 4, -1],
            ['name2', 1, -2]
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])
        table_name = 'pyodps_test_engine_table2'
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)

        try:
            expr2 = CollectionExpr(_source_data=table2, _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)
            self.odps.write_table(table2, 0, [table2.new_record(values=d) for d in data2])

            # Join without explicit keys: every left row should match one right row.
            expr = self.expr.join(expr2)['name', 'id2']

            res = self.engine.execute(expr)
            result = self._get_result(res)

            self.assertEqual(len(result), 5)
            expected = [
                [to_str('name1'), 4],
                [to_str('name2'), 1]
            ]
            self.assertTrue(all(it in expected for it in result))

            # Join on name plus id == id2: only the two (name1, 4) rows match.
            expr = self.expr.join(expr2, on=['name', ('id', 'id2')])[self.expr.name, expr2.id2]
            res = self.engine.execute(expr)
            result = self._get_result(res)
            self.assertEqual(len(result), 2)
            expected = [to_str('name1'), 4]
            self.assertTrue(all(it == expected for it in result))

        finally:
            table2.drop()
Ejemplo n.º 25
0
    def setup(self):
        """Recreate the shared ODPS engine test table and build an expression over it."""
        type_list = [validate_data_type(t) for t in
                     ('string', 'int64', 'float64', 'boolean', 'decimal', 'datetime')]
        schema = Schema.from_lists(['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
                                   type_list)
        self.schema = df_schema_to_odps_schema(schema)

        table_name = 'pyodps_test_engine_table'
        self.odps.delete_table(table_name, if_exists=True)
        self.table = self.odps.create_table(name=table_name, schema=self.schema)
        self.expr = CollectionExpr(_source_data=self.table, _schema=schema)
        self.engine = ODPSEngine(self.odps)

        class FakeBar(object):
            def update(self, *args, **kwargs):
                # Progress updates are intentionally discarded in tests.
                pass
        self.faked_bar = FakeBar()
    def testJoinGroupby(self):
        """Join on ``name``, keep the left columns, then sum ``fid`` per ``id``.

        Expected sums are computed locally from ``data`` with itertools.groupby.
        The helper table is dropped in ``finally``; previously it was never
        dropped.
        """
        data = [
            ['name1', 4, 5.3, None, None],
            ['name2', 2, 3.5, None, None],
            ['name1', 4, 4.2, None, None],
            ['name1', 3, 2.2, None, None],
            ['name1', 3, 4.1, None, None],
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])

        table_name = tn('pyodps_test_engine_table2')
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)
        try:
            expr2 = CollectionExpr(_source_data=table2,
                                   _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)

            data2 = [['name1', 4, -1], ['name2', 1, -2]]

            self.odps.write_table(table2, 0, data2)

            expr = self.expr.join(expr2, on='name')[self.expr]
            expr = expr.groupby('id').agg(expr.fid.sum())

            res = self.engine.execute(expr)
            result = self._get_result(res)

            id_idx = [
                idx for idx, col in enumerate(self.expr.schema.names)
                if col == 'id'
            ][0]
            fid_idx = [
                idx for idx, col in enumerate(self.expr.schema.names)
                if col == 'fid'
            ][0]
            # itertools.groupby needs input sorted by the same key.
            expected = [[k, sum(
                v[fid_idx] for v in row)] for k, row in itertools.groupby(
                    sorted(data, key=lambda r: r[id_idx]), lambda r: r[id_idx])]
            # zip() stops at the shorter list, so check the group count first.
            self.assertEqual(len(result), len(expected))
            for it in zip(sorted(expected, key=lambda it: it[0]),
                          sorted(result, key=lambda it: it[0])):
                self.assertAlmostEqual(it[0][0], it[1][0])
                self.assertAlmostEqual(it[0][1], it[1][1])
        finally:
            table2.drop()
Ejemplo n.º 27
0
    def testApplyMap(self):
        """``applymap`` maps the selected columns and passes the rest through as ``Column``.

        Passing both ``columns`` and ``excludes`` raises ``ValueError``.
        """
        from odps.df.expr.collections import ProjectCollectionExpr, Column
        from odps.df.expr.element import MappedExpr

        schema = Schema.from_lists(['idx', 'f1', 'f2', 'f3'],
                                   [types.int64] + [types.float64] * 3)
        expr = CollectionExpr(_source_data=None, _schema=schema)

        def assert_mapped(mapped, should_map=None):
            # A field must be a MappedExpr iff should_map(name) is true;
            # should_map=None means every field is mapped.
            self.assertIsInstance(mapped, ProjectCollectionExpr)
            for field in mapped._fields:
                is_mapped = should_map is None or should_map(field.name)
                self.assertIsInstance(field, MappedExpr if is_mapped else Column)

        self.assertRaises(
            ValueError, lambda: expr.applymap(
                lambda v: v + 1, columns='idx', excludes='f1'))

        assert_mapped(expr.applymap(lambda v: v + 1))
        assert_mapped(expr.applymap(lambda v: v + 1, columns='f1'),
                      lambda name: name == 'f1')

        map_cols = {'f1', 'f2', 'f3'}
        assert_mapped(expr.applymap(lambda v: v + 1, columns=map_cols),
                      lambda name: name in map_cols)

        assert_mapped(expr.applymap(lambda v: v + 1, excludes='idx'),
                      lambda name: name != 'idx')

        exc_cols = {'idx', 'f1'}
        assert_mapped(expr.applymap(lambda v: v + 1, excludes=exc_cols),
                      lambda name: name not in exc_cols)
Ejemplo n.º 28
0
    def testJoinGroupby(self):
        """Join on ``name``, keep the left columns, then sum ``fid`` per ``id``.

        Expected sums are computed locally from ``data`` with itertools.groupby.
        The group count is checked explicitly, because the pairwise ``zip``
        comparison alone would accept a truncated result.
        """
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        data2 = [['name1', 4, -1], ['name2', 1, -2]]

        datatypes = lambda *types: [validate_data_type(t) for t in types]
        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    datatypes('string', 'int64', 'int64'))
        table_name = tn('pyodps_test_engine_table2')
        # NOTE(review): table2 is never dropped here; sibling tests close
        # their cached connections and call table2.drop() in a finally block.
        # Confirm whether that cleanup should be added.
        table2 = self._create_table_and_insert_data(table_name, schema2, data2)
        expr2 = CollectionExpr(_source_data=table2, _schema=schema2)

        self._gen_data(data=data)

        expr = self.expr.join(expr2, on='name')[self.expr]
        expr = expr.groupby('id').agg(expr.fid.sum())

        res = self.engine.execute(expr)
        result = self._get_result(res)

        id_idx = [
            idx for idx, col in enumerate(self.expr.schema.names)
            if col == 'id'
        ][0]
        fid_idx = [
            idx for idx, col in enumerate(self.expr.schema.names)
            if col == 'fid'
        ][0]
        # itertools.groupby needs input sorted by the same key.
        expected = [[k, sum(
            v[fid_idx] for v in row)] for k, row in itertools.groupby(
                sorted(data, key=lambda r: r[id_idx]), lambda r: r[id_idx])]
        self.assertEqual(len(result), len(expected))
        for it in zip(sorted(expected, key=lambda it: it[0]),
                      sorted(result, key=lambda it: it[0])):
            self.assertAlmostEqual(it[0][0], it[1][0])
            self.assertAlmostEqual(it[0][1], it[1][1])
Ejemplo n.º 29
0
    def testJoinGroupby(self):
        """Join on ``name``, keep the left columns, then sum ``fid`` per ``id``.

        Expected sums come from a pandas groupby over ``data``. The helper
        table is dropped in ``finally``, and the group count is checked before
        the pairwise comparison.
        """
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])

        table_name = 'pyodps_test_engine_table2'
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)
        try:
            expr2 = CollectionExpr(_source_data=table2,
                                   _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)

            data2 = [['name1', 4, -1], ['name2', 1, -2]]

            self.odps.write_table(table2, 0,
                                  [table2.new_record(values=d) for d in data2])

            expr = self.expr.join(expr2, on='name')[self.expr]
            expr = expr.groupby('id').agg(expr.fid.sum())

            res = self.engine.execute(expr)
            result = self._get_result(res)

            import pandas as pd
            expected = pd.DataFrame(data, columns=self.expr.schema.names).groupby('id').agg({'fid': 'sum'})\
                .reset_index().values.tolist()
            # zip() stops at the shorter list, so check the group count first.
            self.assertEqual(len(result), len(expected))
            for it in zip(sorted(expected, key=lambda it: it[0]),
                          sorted(result, key=lambda it: it[0])):
                self.assertAlmostEqual(it[0][0], it[1][0])
                self.assertAlmostEqual(it[0][1], it[1][1])
        finally:
            table2.drop()
    def testFilterOrder(self):
        """Rows with a zero divisor are filtered out before the quotient is filtered.

        The table contains zero divisors. Only rows with ``divisor > 0`` are
        divided, and only quotients above 1 are kept. The table is dropped in
        ``finally``.
        """
        table_name = tn('pyodps_test_division_error')
        self.odps.delete_table(table_name, if_exists=True)
        table = self.odps.create_table(table_name,
                                       'divided bigint, divisor bigint',
                                       lifecycle=1)

        try:
            rows = [[2, 0], [1, 1], [1, 2], [5, 1], [5, 0]]
            self.odps.write_table(table_name, rows)
            df_schema = odps_schema_to_df_schema(table.schema)
            df = CollectionExpr(_source_data=table, _schema=df_schema)

            positive = df[df.divisor > 0]
            quotient = (positive.divided / positive.divisor).rename('result')
            divided = positive[quotient, ]
            expr = divided[divided.result > 1]

            result = self._get_result(self.engine.execute(expr))
            self.assertEqual(result, [[5]])
        finally:
            table.drop()
Ejemplo n.º 31
0
class Test(TestBase):
    """Column-pruning and SQL-compilation tests against mocked tables.

    Every source is a ``MockTable``, so nothing is executed: each test either
    inspects the DAG returned by ``ColumnPruning(expr.to_dag()).prune()`` or
    compares the text produced by ``ODPSEngine.compile(expr, prettify=False)``
    with an expected SQL string.
    """

    def setup(self):
        """Build the mocked collection expressions shared by the tests.

        * ``expr``, ``expr1``, ``expr2``: columns ``name, id, fid, isMale,
          scale, birth`` plus partition column ``ds``, each backed by its own
          mock table.
        * ``expr3``: columns ``name, id, fid`` with partitions ``part1, part2``.
          NOTE(review): its mock table is also named
          ``pyodps_test_expr_table2`` (the same name as ``expr2``'s); the
          expected SQL of several tests relies on that name, so it is kept.
        * ``expr4``: schema ``id, name, relatives (dict), hobbies (list)``,
          backed by a ``pyodps_test_expr_table`` mock built from the first
          schema.
        """
        def datatypes(*types):
            return [validate_data_type(t) for t in types]

        schema = Schema.from_lists(
            ['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
            datatypes('string', 'int64', 'float64', 'boolean', 'decimal',
                      'datetime'), ['ds'], datatypes('string'))
        table = MockTable(name='pyodps_test_expr_table', schema=schema)
        self.expr = CollectionExpr(_source_data=table,
                                   _schema=Schema(columns=schema.columns))

        table1 = MockTable(name='pyodps_test_expr_table1', schema=schema)
        self.expr1 = CollectionExpr(_source_data=table1,
                                    _schema=Schema(columns=schema.columns))

        table2 = MockTable(name='pyodps_test_expr_table2', schema=schema)
        self.expr2 = CollectionExpr(_source_data=table2,
                                    _schema=Schema(columns=schema.columns))

        schema2 = Schema.from_lists(['name', 'id', 'fid'],
                                    datatypes('string', 'int64', 'float64'),
                                    ['part1', 'part2'],
                                    datatypes('string', 'int64'))
        table3 = MockTable(name='pyodps_test_expr_table2', schema=schema2)
        self.expr3 = CollectionExpr(_source_data=table3,
                                    _schema=Schema(columns=schema2.columns))

        schema3 = Schema.from_lists(['id', 'name', 'relatives', 'hobbies'],
                                    datatypes('int64', 'string',
                                              'dict<string, string>',
                                              'list<string>'))
        table4 = MockTable(name='pyodps_test_expr_table', schema=schema)
        self.expr4 = CollectionExpr(_source_data=table4, _schema=schema3)

    def _compile(self, expr):
        """Return the non-prettified SQL text for *expr*, passed through ``to_str``."""
        return to_str(ODPSEngine(self.odps).compile(expr, prettify=False))

    def _assert_sql(self, expected, expr):
        """Assert that *expr* compiles to exactly the SQL text *expected*."""
        self.assertEqual(to_str(expected), self._compile(expr))

    def testProjectPrune(self):
        """Projections prune to the selected columns and compile to a flat SELECT."""
        expr = self.expr.select('name', 'id')
        new_expr = ColumnPruning(expr.to_dag()).prune()
        self.assertIsInstance(new_expr, ProjectCollectionExpr)
        self.assertIsNotNone(new_expr.input._source_data)

        expected = 'SELECT t1.`name`, t1.`id` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1'
        self._assert_sql(expected, expr)

        # Constants, including a typed NULL, next to a real column.
        expr = self.expr[Scalar(3).rename('const'),
                         NullScalar('string').rename('string_const'),
                         self.expr.id]
        expected = 'SELECT 3 AS `const`, CAST(NULL AS STRING) AS `string_const`, t1.`id` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1'
        self._assert_sql(expected, expr)

        # A builtin function whose argument is the source table's name.
        expr = self.expr.select(
            pt=BuiltinFunction('max_pt', args=(self.expr._source_data.name, )))
        expected = "SELECT max_pt('pyodps_test_expr_table') AS `pt` \n" \
                   "FROM mocked_project.`pyodps_test_expr_table` t1"
        self._assert_sql(expected, expr)

    def testApplyPrune(self):
        """Selecting one output of a row-wise apply keeps the filter/source chain."""
        @output(['name', 'id'], ['string', 'string'])
        def h(row):
            yield row[0], row[1]

        expr = self.expr[self.expr.fid < 0].apply(h, axis=1)['id', ]
        new_expr = ColumnPruning(expr.to_dag()).prune()

        self.assertIsInstance(new_expr, ProjectCollectionExpr)
        self.assertIsInstance(new_expr.input.input, FilterCollectionExpr)
        self.assertIsNotNone(new_expr.input.input.input._source_data)

    def testFilterPrune(self):
        """Filters compile to WHERE clauses without an extra projection layer."""
        expr = self.expr.filter(self.expr.name == 'name1')
        expr = expr['name', 'id']

        new_expr = ColumnPruning(expr.to_dag()).prune()

        self.assertIsInstance(new_expr.input, FilterCollectionExpr)
        self.assertNotIsInstance(new_expr.input.input, ProjectCollectionExpr)
        self.assertIsNotNone(new_expr.input.input._source_data)

        expected = 'SELECT t1.`name`, t1.`id` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'WHERE t1.`name` == \'name1\''
        self._assert_sql(expected, expr)

        expr = self.expr.filter(self.expr.name == 'name1')

        new_expr = ColumnPruning(expr.to_dag()).prune()

        self.assertIsInstance(new_expr, FilterCollectionExpr)
        self.assertIsNotNone(new_expr.input._source_data)

        expr = self.expr.filter(self.expr.id.isin(self.expr3.id))

        expected = 'SELECT * \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'WHERE t1.`id` IN (SELECT t3.`id` FROM (  ' \
                   'SELECT t2.`id`   FROM mocked_project.`pyodps_test_expr_table2` t2 ) t3)'
        # This used to be ``assertTrue(expected, actual)``, which treats the
        # second argument as the failure message and therefore never fails.
        # The expected text appears to have been written with the sub-query's
        # line breaks flattened, so whitespace runs are collapsed on both
        # sides before comparing.
        self.assertEqual(' '.join(to_str(expected).split()),
                         ' '.join(self._compile(expr).split()))

    def testFilterPartsPrune(self):
        """``filter_parts`` becomes an inner WHERE on the partition column."""
        expr = self.expr.filter_parts('ds=today')[lambda x: x.fid < 0][
            'name', lambda x: x.id + 1]

        new_expr = ColumnPruning(expr.to_dag()).prune()
        self.assertEqual(set(new_expr.input.input.schema.names),
                         set(['name', 'id', 'fid']))

        expected = "SELECT t2.`name`, t2.`id` + 1 AS `id` \n" \
                   "FROM (\n" \
                   "  SELECT t1.`name`, t1.`id`, t1.`fid` \n" \
                   "  FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "  WHERE t1.`ds` == 'today' \n" \
                   ") t2 \n" \
                   "WHERE t2.`fid` < 0"
        self._assert_sql(expected, expr)

    def testSlicePrune(self):
        """Filter + slice + projection compile to one SELECT with LIMIT."""
        expr = self.expr.filter(self.expr.fid < 0)[:4]['name',
                                                       lambda x: x.id + 1]

        new_expr = ColumnPruning(expr.to_dag()).prune()
        self.assertIsNotNone(new_expr.input.input.input._source_data)

        expected = "SELECT t1.`name`, t1.`id` + 1 AS `id` \n" \
                   "FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "WHERE t1.`fid` < 0 \n" \
                   "LIMIT 4"
        self._assert_sql(expected, expr)

    def testGroupbyPrune(self):
        """A filter on an aggregate becomes HAVING, whether or not it is selected."""
        expr = self.expr.groupby('name').agg(id=self.expr.id.max())
        expr = expr[expr.id < 0]['name', ]

        expected = "SELECT t1.`name` \n" \
                   "FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "GROUP BY t1.`name` \n" \
                   "HAVING MAX(t1.`id`) < 0"
        self._assert_sql(expected, expr)

        expr = self.expr.groupby('name').agg(id=self.expr.id.max())
        expr = expr[expr.id < 0]['id', ]

        expected = "SELECT MAX(t1.`id`) AS `id` \n" \
                   "FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "GROUP BY t1.`name` \n" \
                   "HAVING MAX(t1.`id`) < 0"
        self._assert_sql(expected, expr)

    def testMutatePrune(self):
        """Only the columns feeding ``new_id``/``new_id_sum`` reach the sub-query."""
        expr = self.expr[self.expr.exclude('birth'),
                         self.expr.fid.astype('int').rename('new_id')]
        expr = expr[expr,
                    expr.groupby('name').
                    mutate(lambda x: x.new_id.cumsum().rename('new_id_sum'))]
        expr = expr[expr.new_id, expr.new_id_sum]

        expected = "SELECT t2.`new_id`, t2.`new_id_sum` \n" \
                   "FROM (\n" \
                   "  SELECT CAST(t1.`fid` AS BIGINT) AS `new_id`, " \
                   "SUM(CAST(t1.`fid` AS BIGINT)) OVER (PARTITION BY t1.`name`) AS `new_id_sum` \n" \
                   "  FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   ") t2"

        self._assert_sql(expected, expr)

    def testValueCountsPrune(self):
        """``value_counts`` projects only the counted column underneath."""
        expr = self.expr.name.value_counts()['count', ]
        new_expr = ColumnPruning(expr.to_dag()).prune()

        self.assertIsInstance(new_expr.input.input, ProjectCollectionExpr)
        self.assertEqual(set(new_expr.input.input.schema.names), set(['name']))

    def testSortPrune(self):
        """Columns dropped after a sort are pruned from the sorted sub-query.

        The expected SQL also contains the ``LIMIT 10000`` emitted with the
        ORDER BY.
        """
        expr = self.expr[self.expr.exclude('name'),
                         self.expr.name.rename('name2')].sort('name2')['id',
                                                                       'fid']

        expected = "SELECT t2.`id`, t2.`fid` \n" \
                   "FROM (\n" \
                   "  SELECT t1.`id`, t1.`fid`, t1.`name` AS `name2` \n" \
                   "  FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "  ORDER BY name2 \n" \
                   "  LIMIT 10000\n" \
                   ") t2"
        self._assert_sql(expected, expr)

    def testDistinctPrune(self):
        """DISTINCT keeps all of its keys even if only one is selected afterwards."""
        expr = self.expr.distinct(self.expr.id + 1, self.expr.name)['name', ]

        expected = "SELECT t2.`name` \n" \
                   "FROM (\n" \
                   "  SELECT DISTINCT t1.`id` + 1 AS `id`, t1.`name` \n" \
                   "  FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   ") t2"
        self._assert_sql(expected, expr)

    def testSamplePrune(self):
        """Sampling compiles to a WHERE SAMPLE(...) on the pruned projection."""
        expr = self.expr['name', 'id'].sample(parts=5)['id', ]

        expected = "SELECT t1.`id` \n" \
                   "FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "WHERE SAMPLE(5, 1)"
        self._assert_sql(expected, expr)

    def testJoinPrune(self):
        """Joins prune unused columns and suffix clashing names ``_x``/``_y``."""
        left = self.expr.select(self.expr, type='normal')
        right = self.expr3[:4]
        joined = left.left_join(right, on='id')
        expr = joined.id_x.rename('id')

        expected = "SELECT t2.`id` \n" \
                   "FROM (\n" \
                   "  SELECT t1.`id` \n" \
                   "  FROM mocked_project.`pyodps_test_expr_table` t1\n" \
                   ") t2 \n" \
                   "LEFT OUTER JOIN \n" \
                   "  (\n" \
                   "    SELECT t3.`id` \n" \
                   "    FROM mocked_project.`pyodps_test_expr_table2` t3 \n" \
                   "    LIMIT 4\n" \
                   "  ) t4\n" \
                   "ON t2.`id` == t4.`id`"

        self._assert_sql(expected, expr)

        joined = self.expr.join(self.expr2, 'name')

        expected = 'SELECT t1.`name`, t1.`id` AS `id_x`, t1.`fid` AS `fid_x`, ' \
                   't1.`isMale` AS `isMale_x`, t1.`scale` AS `scale_x`, ' \
                   't1.`birth` AS `birth_x`, t1.`ds` AS `ds_x`, t2.`id` AS `id_y`, ' \
                   't2.`fid` AS `fid_y`, t2.`isMale` AS `isMale_y`, t2.`scale` AS `scale_y`, ' \
                   't2.`birth` AS `birth_y`, t2.`ds` AS `ds_y` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'INNER JOIN \n' \
                   '  mocked_project.`pyodps_test_expr_table2` t2\n' \
                   'ON t1.`name` == t2.`name`'
        self._assert_sql(expected, joined)

        joined = self.expr.join(self.expr2,
                                on=[self.expr.name == self.expr2.name])
        joined2 = joined.join(self.expr, on=[joined.id_x == self.expr.id])

        expected = 'SELECT t1.`name` AS `name_x`, t1.`id` AS `id_x`, ' \
                   't1.`fid` AS `fid_x`, t1.`isMale` AS `isMale_x`, ' \
                   't1.`scale` AS `scale_x`, t1.`birth` AS `birth_x`, ' \
                   't1.`ds` AS `ds_x`, t2.`id` AS `id_y`, t2.`fid` AS `fid_y`, ' \
                   't2.`isMale` AS `isMale_y`, t2.`scale` AS `scale_y`, ' \
                   't2.`birth` AS `birth_y`, t2.`ds` AS `ds_y`, ' \
                   't3.`name` AS `name_y`, t3.`id`, t3.`fid`, t3.`isMale`, ' \
                   't3.`scale`, t3.`birth`, t3.`ds` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'INNER JOIN \n' \
                   '  mocked_project.`pyodps_test_expr_table2` t2\n' \
                   'ON t1.`name` == t2.`name` \n' \
                   'INNER JOIN \n' \
                   '  mocked_project.`pyodps_test_expr_table` t3\n' \
                   'ON t1.`id` == t3.`id`'
        self._assert_sql(expected, joined2)

        # Same chained join with map-join hints on both joins.
        joined = self.expr.join(self.expr2,
                                on=[self.expr.name == self.expr2.name],
                                mapjoin=True)
        joined2 = joined.join(self.expr,
                              on=[joined.id_x == self.expr.id],
                              mapjoin=True)

        expected = 'SELECT /*+mapjoin(t2,t3)*/ t1.`name` AS `name_x`, t1.`id` AS `id_x`, ' \
                   't1.`fid` AS `fid_x`, t1.`isMale` AS `isMale_x`, t1.`scale` AS `scale_x`, ' \
                   't1.`birth` AS `birth_x`, t1.`ds` AS `ds_x`, t2.`id` AS `id_y`, ' \
                   't2.`fid` AS `fid_y`, t2.`isMale` AS `isMale_y`, t2.`scale` AS `scale_y`, ' \
                   't2.`birth` AS `birth_y`, t2.`ds` AS `ds_y`, t3.`name` AS `name_y`, ' \
                   't3.`id`, t3.`fid`, t3.`isMale`, t3.`scale`, t3.`birth`, t3.`ds` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'INNER JOIN \n' \
                   '  mocked_project.`pyodps_test_expr_table2` t2\n' \
                   'ON t1.`name` == t2.`name` \n' \
                   'INNER JOIN \n' \
                   '  mocked_project.`pyodps_test_expr_table` t3\n' \
                   'ON t1.`id` == t3.`id`'
        self._assert_sql(expected, joined2)

    def testUnionPrune(self):
        """UNION ALL branches are pruned to the columns selected afterwards."""
        left = self.expr.select('name', 'id')
        right = self.expr3.select(
            self.expr3.fid.astype('int').rename('id'), self.expr3.name)
        expr = left.union(right)['id']

        expected = "SELECT t3.`id` \n" \
                   "FROM (\n" \
                   "  SELECT t1.`id` \n" \
                   "  FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "  UNION ALL\n" \
                   "    SELECT CAST(t2.`fid` AS BIGINT) AS `id` \n" \
                   "    FROM mocked_project.`pyodps_test_expr_table2` t2\n" \
                   ") t3"
        self._assert_sql(expected, expr)

        expr = self.expr.union(self.expr2)

        expected = 'SELECT * \n' \
                   'FROM (\n' \
                   '  SELECT * \n' \
                   '  FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   '  UNION ALL\n' \
                   '    SELECT * \n' \
                   '    FROM mocked_project.`pyodps_test_expr_table2` t2\n' \
                   ') t3'

        self._assert_sql(expected, expr)

    def testLateralViewPrune(self):
        """``explode`` compiles to LATERAL VIEW and prunes unselected columns."""
        expr = self.expr4['name', 'id', self.expr4.hobbies.explode()]
        new_expr = ColumnPruning(expr.to_dag()).prune()
        self.assertIsInstance(new_expr, LateralViewCollectionExpr)
        self.assertIsNotNone(new_expr.input._source_data)

        expected = 'SELECT t1.`name`, t1.`id`, t2.`hobbies` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'LATERAL VIEW EXPLODE(t1.`hobbies`) t2 AS `hobbies`'
        self._assert_sql(expected, expr)

        expected = 'SELECT t1.`id`, t2.`hobbies` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'LATERAL VIEW EXPLODE(t1.`hobbies`) t2 AS `hobbies`'

        expr2 = expr[expr.id, expr.hobbies]
        self._assert_sql(expected, expr2)
Ejemplo n.º 32
0
class Test(TestBase):
    def setup(self):
        """Create a fresh engine test table and the shared fixtures.

        Any existing ``pyodps_test_engine_table`` is dropped and recreated from
        a six-column DataFrame schema converted with ``df_schema_to_odps_schema``.
        Sets ``self.schema`` (the ODPS schema), ``self.table``, ``self.expr``
        (a collection over the table), ``self.engine`` (an ``ODPSEngine``) and
        ``self.faked_bar`` (a progress-bar stand-in whose ``update`` does
        nothing, passed to ``_handle_cases`` by the tests).
        """
        def datatypes(*types):
            return [validate_data_type(t) for t in types]

        schema = Schema.from_lists(['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
                                   datatypes('string', 'int64', 'float64', 'boolean', 'decimal', 'datetime'))
        self.schema = df_schema_to_odps_schema(schema)
        table_name = 'pyodps_test_engine_table'
        self.odps.delete_table(table_name, if_exists=True)
        # Reuse ``table_name`` so the dropped and the created table cannot diverge.
        self.table = self.odps.create_table(name=table_name, schema=self.schema)
        self.expr = CollectionExpr(_source_data=self.table, _schema=schema)

        self.engine = ODPSEngine(self.odps)

        class FakeBar(object):
            """Progress-bar stand-in; ``update`` ignores all arguments."""

            def update(self, *args, **kwargs):
                pass

        self.faked_bar = FakeBar()

    def _gen_data(self, rows=None, data=None, nullable_field=None, value_range=None):
        """Write rows into ``self.table`` (block 0) and return them.

        :param rows: number of random rows to generate; required when *data*
            is ``None``.
        :param data: explicit list of rows.  When given it is written as-is and
            *rows*, *nullable_field* and *value_range* are ignored.
        :param nullable_field: name of a column that is set to ``None`` on every
            even-indexed generated row (0, 2, 4, ...).
        :param value_range: forwarded only to the ``bigint`` generator.
        :return: the list of rows written (the *data* object itself when it was
            supplied).
        """
        if data is None:
            data = []
            for _ in range(rows):
                record = []
                for t in self.schema.types:
                    # Generators are looked up by ODPS type name, e.g.
                    # ``_gen_random_bigint`` (presumably defined on TestBase).
                    method = getattr(self, '_gen_random_%s' % t.name)
                    if t.name == 'bigint':
                        record.append(method(value_range=value_range))
                    else:
                        record.append(method())
                data.append(record)

            if nullable_field is not None:
                j = self.schema._name_indexes[nullable_field]
                # The slice shares the row lists with ``data``, so this nulls
                # the field in place on every other generated row.
                for record in data[::2]:
                    record[j] = None

        self.odps.write_table(self.table, 0, [self.table.new_record(values=d) for d in data])
        return data

    def testTunnelCases(self):
        """Exercise ``ODPSEngine._handle_cases`` (the tunnel path, per the test name).

        On the plain table, counts, the full collection and a simple projection
        return the written data.  On tables persisted with one or two partition
        columns, an unfiltered count yields ``None``, while counts/projections
        filtered on the partition values return non-empty results.
        """
        data = self._gen_data(10, value_range=(-1000, 1000))

        def handle(e):
            return self.engine._handle_cases(e, self.faked_bar)

        # Unpartitioned source table.
        self.assertEqual(10, self._get_result(handle(self.expr.count())))
        self.assertEqual(10, self._get_result(handle(self.expr.name.count())))
        self.assertEqual(data, self._get_result(handle(self.expr)))

        projected = self.expr['name', self.expr.id.rename('new_id')]
        self.assertEqual([row[:2] for row in data],
                         self._get_result(handle(projected)))

        table_name = 'pyodps_test_engine_partitioned'
        self.odps.delete_table(table_name, if_exists=True)

        # Partitioned by ``name``.
        df = self.engine.persist(self.expr, table_name, partitions=['name'])
        try:
            self.assertIsNone(handle(df.count()))

            selected = df[df.name == data[0][0]]['fid', 'id']
            counted = self.engine._pre_process(selected.count())
            self.assertGreater(handle(counted), 0)

            selected = df[df.name == data[0][0]]['fid', 'id']
            self.assertGreater(len(handle(selected)), 0)
        finally:
            self.odps.delete_table(table_name, if_exists=True)

        # Partitioned by ``name`` and ``id``.
        df = self.engine.persist(self.expr, table_name, partitions=['name', 'id'])
        try:
            self.assertIsNone(handle(df.count()))

            selected = df[(df.name == data[0][0]) & (df.id == data[0][1])]['fid', 'ismale']
            counted = self.engine._pre_process(selected.count())
            self.assertGreater(handle(counted), 0)

            selected = df[(df.name == data[0][0]) & (df.id == data[0][1])]['fid', 'ismale']
            self.assertGreater(len(handle(selected)), 0)
        finally:
            self.odps.delete_table(table_name, if_exists=True)

    def testAsync(self):
        """Asynchronous execution returns before termination, then fetches the sum.

        ``async`` is a reserved keyword from Python 3.7 on, so writing
        ``async=True`` makes the whole module a SyntaxError there.  Passing the
        same keyword through ``**{...}`` reaches ``execute`` with the identical
        argument name on every Python version.
        """
        data = self._gen_data(10, value_range=(-1000, 1000))

        expr = self.expr.id.sum()

        res = self.engine.execute(expr, **{'async': True})
        self.assertNotEqual(res.instance.status, Instance.Status.TERMINATED)
        res.wait()

        self.assertEqual(sum(it[1] for it in data), res.fetch())

    def testBase(self):
        """Smoke-test filter/projection, constant columns and sort + limit."""
        data = self._gen_data(10, value_range=(-1000, 1000))

        # Filter, then project a column by name and by lambda.
        filtered = self.expr[self.expr.id < 10]['name', lambda x: x.id]
        rows = self._get_result(self.engine.execute(filtered).values)
        self.assertEqual(sum(1 for rec in data if rec[1] < 10), len(rows))
        if len(rows) > 0:
            self.assertEqual(2, len(rows[0]))

        # A constant, a plain column and a derived column side by side.
        proj = self.expr[Scalar(3).rename('const'), self.expr.id, (self.expr.id + 1).rename('id2')]
        res = self.engine.execute(proj)
        rows = self._get_result(res.values)
        self.assertEqual([c.name for c in res.columns], ['const', 'id', 'id2'])
        self.assertTrue(all(row[0] == 3 for row in rows))
        self.assertEqual(len(data), len(rows))
        self.assertEqual([rec[1] + 1 for rec in data], [row[2] for row in rows])

        # Sort + limit, first with default options and then without the
        # tunnel; a fresh expression is built for each run.
        expected_top5 = sorted(data, key=lambda rec: rec[1])[:5]
        for options in ({}, {'use_tunnel': False}):
            top5 = self.expr.sort('id')[:5]
            rows = self._get_result(self.engine.execute(top5, **options).values)
            self.assertEqual(expected_top5, rows)

    def testElement(self):
        """Element-wise operators: null checks, fillna, isin/notin, between, switch, cut."""
        data = self._gen_data(5, nullable_field='name')

        fields = [
            self.expr.name.isnull().rename('name1'),
            self.expr.name.notnull().rename('name2'),
            self.expr.name.fillna('test').rename('name3'),
            self.expr.id.isin([1, 2, 3]).rename('id1'),
            self.expr.id.isin(self.expr.fid.astype('int')).rename('id2'),
            self.expr.id.notin([1, 2, 3]).rename('id3'),
            self.expr.id.notin(self.expr.fid.astype('int')).rename('id4'),
            self.expr.id.between(self.expr.fid, 3).rename('id5'),
            self.expr.name.fillna('test').switch('test', 'test' + self.expr.name.fillna('test'),
                                                 'test2', 'test2' + self.expr.name.fillna('test'),
                                                 default=self.expr.name).rename('name4'),
            self.expr.id.cut([100, 200, 300],
                             labels=['xsmall', 'small', 'large', 'xlarge'],
                             include_under=True, include_over=True).rename('id6')
        ]

        rows = self._get_result(self.engine.execute(self.expr[fields]))

        def column(idx):
            return [row[idx] for row in rows]

        self.assertEqual(len(data), len(rows))

        # isnull / notnull: number of truthy flags vs. null / non-null names.
        self.assertEqual(sum(1 for rec in data if rec[0] is None),
                         len([flag for flag in column(0) if flag]))
        self.assertEqual(sum(1 for rec in data if rec[0] is not None),
                         len([flag for flag in column(1) if flag]))

        self.assertEqual([rec[0] if rec[0] is not None else 'test' for rec in data],
                         column(2))

        small_ids = (1, 2, 3)
        self.assertEqual([rec[1] in small_ids for rec in data], column(3))

        fids = [int(rec[2]) for rec in data]
        self.assertEqual([rec[1] in fids for rec in data], column(4))
        self.assertEqual([rec[1] not in small_ids for rec in data], column(5))
        self.assertEqual([rec[1] not in fids for rec in data], column(6))

        self.assertEqual([rec[2] <= rec[1] <= 3 for rec in data], column(7))

        self.assertEqual([to_str('testtest' if rec[0] is None else rec[0]) for rec in data],
                         [to_str(v) for v in column(8)])

        def bucket(val):
            # Python mirror of cut([100, 200, 300]) with under/over buckets.
            if val <= 100:
                return 'xsmall'
            if val <= 200:
                return 'small'
            if val <= 300:
                return 'large'
            return 'xlarge'

        self.assertEqual([to_str(bucket(rec[1])) for rec in data],
                         [to_str(v) for v in column(9)])

    def testArithmetic(self):
        """Arithmetic, unary and datetime operators agree with their Python results.

        Float-valued columns are compared element by element: handing whole
        lists to ``assertAlmostEqual`` raises TypeError (it subtracts its
        arguments) whenever they are not exactly equal.
        """
        data = self._gen_data(5, value_range=(-1000, 1000))

        fields = [
            (self.expr.id + 1).rename('id1'),
            (self.expr.fid - 1).rename('fid1'),
            (self.expr.scale * 2).rename('scale1'),
            (self.expr.scale + self.expr.id).rename('scale2'),
            (self.expr.id / 2).rename('id2'),
            (self.expr.id ** -2).rename('id3'),
            abs(self.expr.id).rename('id4'),
            (~self.expr.id).rename('id5'),
            (-self.expr.fid).rename('fid2'),
            (~self.expr.isMale).rename('isMale1'),
            # NOTE(review): column 10 is computed but never asserted below.
            (-self.expr.isMale).rename('isMale2'),
            (self.expr.id // 2).rename('id6'),
            # Rename the sum itself, not the ``day(1)`` operand.
            (self.expr.birth + day(1)).rename('birth1'),
            (self.expr.birth - (self.expr.birth - millisecond(10))).rename('birth2'),
        ]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        def assert_all_almost_equal(expected, actual):
            self.assertEqual(len(expected), len(actual))
            for want, got in zip(expected, actual):
                self.assertAlmostEqual(want, got)

        self.assertEqual(len(data), len(result))

        self.assertEqual([it[1] + 1 for it in data],
                         [it[0] for it in result])

        assert_all_almost_equal([it[2] - 1 for it in data],
                                [it[1] for it in result])

        self.assertEqual([it[4] * 2 for it in data],
                         [it[2] for it in result])

        self.assertEqual([it[4] + it[1] for it in data],
                         [it[3] for it in result])

        assert_all_almost_equal([float(it[1]) / 2 for it in data],
                                [it[4] for it in result])

        # NOTE(review): ``0 ** -2`` raises ZeroDivisionError in Python, so this
        # line fails if a generated id is ever 0 -- confirm the id generator's
        # range excludes it.
        self.assertEqual([int(it[1] ** -2) for it in data],
                         [it[5] for it in result])

        self.assertEqual([abs(it[1]) for it in data],
                         [it[6] for it in result])

        self.assertEqual([~it[1] for it in data],
                         [it[7] for it in result])

        assert_all_almost_equal([-it[2] for it in data],
                                [it[8] for it in result])

        self.assertEqual([not it[3] for it in data],
                         [it[9] for it in result])

        self.assertEqual([it[1] // 2 for it in data],
                         [it[11] for it in result])

        self.assertEqual([it[5] + timedelta(days=1) for it in data],
                         [it[12] for it in result])

        # ``birth - (birth - 10ms)`` is expected to come back as 10.
        self.assertEqual([10] * len(data), [it[13] for it in result])

    def testMath(self):
        """Math functions on ``id`` agree with their NumPy counterparts.

        A value that is NaN on both sides counts as a match.
        """
        data = self._gen_data(5, value_range=(1, 90))

        import numpy as np

        pairs = [
            (np.sin, self.expr.id.sin()),
            (np.cos, self.expr.id.cos()),
            (np.tan, self.expr.id.tan()),
            (np.sinh, self.expr.id.sinh()),
            (np.cosh, self.expr.id.cosh()),
            (np.tanh, self.expr.id.tanh()),
            (np.log, self.expr.id.log()),
            (np.log2, self.expr.id.log2()),
            (np.log10, self.expr.id.log10()),
            (np.log1p, self.expr.id.log1p()),
            (np.exp, self.expr.id.exp()),
            (np.expm1, self.expr.id.expm1()),
            (np.arccosh, self.expr.id.arccosh()),
            (np.arcsinh, self.expr.id.arcsinh()),
            (np.arctanh, self.expr.id.arctanh()),
            (np.arctan, self.expr.id.arctan()),
            (np.sqrt, self.expr.id.sqrt()),
            (np.abs, self.expr.id.abs()),
            (np.ceil, self.expr.id.ceil()),
            (np.floor, self.expr.id.floor()),
            (np.trunc, self.expr.id.trunc()),
        ]

        renamed = [field.rename('id' + str(idx)) for idx, (_, field) in enumerate(pairs)]
        rows = self._get_result(self.engine.execute(self.expr[renamed]))

        for idx, (np_func, _) in enumerate(pairs):
            expected = [np_func(rec[1]) for rec in data]
            actual = [row[idx] for row in rows]
            self.assertEqual(len(expected), len(actual))
            for want, got in zip(expected, actual):
                if np.isnan(want) and np.isnan(got):
                    continue
                self.assertAlmostEqual(want, got)

    def testString(self):
        """String methods on ``name`` agree with the equivalent Python ``str`` calls.

        The first generated name is used as the search/replace argument.
        """
        data = self._gen_data(5)
        needle = data[0][0]

        checks = [
            (lambda s: s.capitalize(), self.expr.name.capitalize()),
            (lambda s: needle in s, self.expr.name.contains(needle, regex=False)),
            (lambda s: s.count(needle), self.expr.name.count(needle)),
            (lambda s: s.endswith(needle), self.expr.name.endswith(needle)),
            (lambda s: s.startswith(needle), self.expr.name.startswith(needle)),
            (lambda s: s.find(needle), self.expr.name.find(needle)),
            (lambda s: s.rfind(needle), self.expr.name.rfind(needle)),
            (lambda s: s.replace(needle, 'test'), self.expr.name.replace(needle, 'test')),
            (lambda s: s[0], self.expr.name.get(0)),
            (lambda s: len(s), self.expr.name.len()),
            (lambda s: s.ljust(10), self.expr.name.ljust(10)),
            (lambda s: s.ljust(20, '*'), self.expr.name.ljust(20, fillchar='*')),
            (lambda s: s.rjust(10), self.expr.name.rjust(10)),
            (lambda s: s.rjust(20, '*'), self.expr.name.rjust(20, fillchar='*')),
            (lambda s: s * 4, self.expr.name.repeat(4)),
            (lambda s: s[2: 10: 2], self.expr.name.slice(2, 10, 2)),
            (lambda s: s[-5: -1], self.expr.name.slice(-5, -1)),
            (lambda s: s.title(), self.expr.name.title()),
            (lambda s: s.rjust(20, '0'), self.expr.name.zfill(20)),
            (lambda s: s.isalnum(), self.expr.name.isalnum()),
            (lambda s: s.isalpha(), self.expr.name.isalpha()),
            (lambda s: s.isdigit(), self.expr.name.isdigit()),
            (lambda s: s.isspace(), self.expr.name.isspace()),
            (lambda s: s.isupper(), self.expr.name.isupper()),
            (lambda s: s.istitle(), self.expr.name.istitle()),
            (lambda s: to_str(s).isnumeric(), self.expr.name.isnumeric()),
            (lambda s: to_str(s).isdecimal(), self.expr.name.isdecimal()),
        ]

        renamed = [field.rename('id' + str(idx)) for idx, (_, field) in enumerate(checks)]
        rows = self._get_result(self.engine.execute(self.expr[renamed]))

        for idx, (py_func, _) in enumerate(checks):
            self.assertEqual([py_func(rec[0]) for rec in data],
                             [row[idx] for row in rows])

    def testApply(self):
        """Row-wise ``apply`` with a tuple-returning function and with a generator.

        The generator variant yields two values per input row, so the output
        interleaves ``len(name)`` and ``id`` for every row.
        """
        data = self._gen_data(5)

        def my_func(row):
            return row.name,

        single = self.expr['name', 'id'].apply(my_func, axis=1, names='name')
        rows = self._get_result(self.engine.execute(single))
        self.assertEqual([r[0] for r in rows], [r[0] for r in data])

        def my_func2(row):
            yield len(row.name)
            yield row.id

        multi = self.expr['name', 'id'].apply(my_func2, axis=1, names='cnt', types='int')
        rows = self._get_result(self.engine.execute(multi))

        expected = [value for rec in data for value in (len(rec[0]), rec[1])]
        self.assertEqual([r[0] for r in rows], expected)

    def testDatetime(self):
        """Compare datetime field accessors with pandas' ``Series.dt`` equivalents."""
        data = self._gen_data(5)

        import pandas as pd

        methods_to_fields = [
            (lambda s: list(s.birth.dt.year.values), self.expr.birth.year),
            (lambda s: list(s.birth.dt.month.values), self.expr.birth.month),
            (lambda s: list(s.birth.dt.day.values), self.expr.birth.day),
            (lambda s: list(s.birth.dt.hour.values), self.expr.birth.hour),
            (lambda s: list(s.birth.dt.minute.values), self.expr.birth.minute),
            (lambda s: list(s.birth.dt.second.values), self.expr.birth.second),
            (lambda s: list(s.birth.dt.weekofyear.values), self.expr.birth.weekofyear),
            (lambda s: list(s.birth.dt.dayofweek.values), self.expr.birth.dayofweek),
            (lambda s: list(s.birth.dt.weekday.values), self.expr.birth.weekday),
            (lambda s: list(s.birth.dt.date.values), self.expr.birth.date),
            (lambda s: list(s.birth.dt.strftime('%Y%d')), self.expr.birth.strftime('%Y%d'))
        ]

        fields = [it[1].rename('birth'+str(i)) for i, it in enumerate(methods_to_fields)]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        df = pd.DataFrame(data, columns=self.schema.names)

        def conv(v):
            # Reduce pandas Timestamps in the result to ``datetime.date`` so they
            # can compare equal to pandas' ``dt.date`` values. ``to_pydatetime``
            # replaces ``Timestamp.to_datetime``, which pandas deprecated and removed.
            if isinstance(v, pd.Timestamp):
                return v.to_pydatetime().date()
            return v

        for i, (method, _) in enumerate(methods_to_fields):
            first = method(df)
            second = [conv(row[i]) for row in result]
            self.assertEqual(first, second)

    def testSortDistinct(self):
        """Sort, then take distinct (name, id + 1) pairs, limited to 50 rows."""
        rows = [
            ['name1', 4, None, None, None, None],
            ['name2', 2, None, None, None, None],
            ['name1', 4, None, None, None, None],
            ['name1', 3, None, None, None, None],
        ]
        self._gen_data(data=rows)

        ordered = self.expr.sort(['name', -self.expr.id])
        expr = ordered.distinct(['name', lambda x: x.id + 1])[:50]

        result = self._get_result(self.engine.execute(expr))

        self.assertEqual(3, len(result))
        self.assertEqual([['name1', 5], ['name1', 4], ['name2', 3]], result)

    def testGroupbyAggregation(self):
        """Grouped filter+agg, grouped row_number, value_counts, topk, count and nunique."""
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        def run(e):
            return self._get_result(self.engine.execute(e))

        # Keep only (name, id) groups whose doubled minimum fid is below 8.
        kept = self.expr.groupby(['name', 'id'])[lambda x: x.fid.min() * 2 < 8]
        agg_expr = kept.agg(self.expr.fid.max() + 1, new_id=self.expr.id.sum())
        self.assertEqual([['name1', 3, 5.1, 6], ['name2', 2, 4.5, 2]],
                         sorted(run(agg_expr), key=lambda k: k[0]))

        # Per-name row numbers ordered by id ascending, fid descending.
        rank = self.expr.groupby('name').sort(['id', -self.expr.fid]).row_number()
        ranked = run(self.expr['name', 'id', 'fid', rank])
        self.assertEqual(
            [['name1', 3, 4.1, 1],
             ['name1', 3, 2.2, 2],
             ['name1', 4, 5.3, 3],
             ['name1', 4, 4.2, 4],
             ['name2', 2, 3.5, 1]],
            sorted(ranked, key=lambda k: (k[0], k[1], -k[2])))

        per_name = [['name1', 4], ['name2', 1]]
        self.assertEqual(per_name, run(self.expr.name.value_counts()[:25]))
        self.assertEqual(per_name, run(self.expr.name.topk(25)))
        self.assertEqual([row[1:] for row in per_name],
                         run(self.expr.groupby('name').count()))

        distinct_ids = [['name1', 2], ['name2', 1]]
        self.assertEqual([row[1:] for row in distinct_ids],
                         run(self.expr.groupby('name').id.nunique()))

        filtered_counts = self.expr[self.expr['id'] > 2].name.value_counts()[:25]
        self.assertEqual([['name1', 4]], run(filtered_counts))

    def testJoinGroupby(self):
        """Join on ``name``, keep the left columns, then sum ``fid`` per ``id``.

        The expected values come from grouping the left data directly with
        pandas; every left row matches exactly one right row on ``name``.
        """
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])

        table_name = 'pyodps_test_engine_table2'
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)

        # Everything after creating table2 runs under ``try`` so the table is
        # always dropped, matching testJoin/testUnion.
        try:
            expr2 = CollectionExpr(_source_data=table2, _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)

            data2 = [
                ['name1', 4, -1],
                ['name2', 1, -2]
            ]

            self.odps.write_table(table2, 0, [table2.new_record(values=d) for d in data2])

            expr = self.expr.join(expr2, on='name')[self.expr]
            expr = expr.groupby('id').agg(expr.fid.sum())

            res = self.engine.execute(expr)
            result = self._get_result(res)

            import pandas as pd
            expected = pd.DataFrame(data, columns=self.expr.schema.names).groupby('id').agg({'fid': 'sum'})\
                .reset_index().values.tolist()

            # zip() stops at the shorter list, so check lengths explicitly first.
            self.assertEqual(len(expected), len(result))
            for exp_row, res_row in zip(sorted(expected, key=lambda it: it[0]),
                                        sorted(result, key=lambda it: it[0])):
                self.assertAlmostEqual(exp_row[0], res_row[0])
                self.assertAlmostEqual(exp_row[1], res_row[1])
        finally:
            table2.drop()

    def testFilterGroupby(self):
        """Filter the output of a grouped aggregation on the aggregated column."""
        self._gen_data(data=[
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ])

        max_id_per_name = self.expr.groupby(['name']).agg(id=self.expr.id.max())
        result = self._get_result(
            self.engine.execute(max_id_per_name[lambda x: x.id > 3]))

        self.assertEqual(1, len(result))
        self.assertEqual([['name1', 4]], result)

    def testWindowRewrite(self):
        """Chain filters/projections built on whole-column reductions and
        compare with the same computation in pandas."""
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        near_mean = self.expr[self.expr.id - self.expr.id.mean() < 10]
        from_max = near_mean[[lambda x: x.id - x.id.max()]]
        from_min = from_max[[lambda x: x.id - x.id.min()]]
        expr = from_min[lambda x: x.id - x.id.std() > 0]

        result = self._get_result(self.engine.execute(expr))

        import pandas as pd
        ids = pd.DataFrame(data, columns=self.schema.names).id
        shifted = ids - ids.max()
        shifted = shifted - shifted.min()
        expected = list(shifted[shifted - shifted.std() > 0])

        self.assertEqual(expected, [row[0] for row in result])

    def testReduction(self):
        """Compare whole-column reductions with their pandas equivalents.

        Each entry pairs a function of a pandas DataFrame with the matching
        ODPS reduction expression. All fields are reductions, so the engine
        result is checked to be a single row.
        """
        data = self._gen_data(rows=5, value_range=(-100, 100))

        import pandas as pd
        df = pd.DataFrame(data, columns=self.schema.names)

        # The lambdas use their DataFrame argument rather than a closure over
        # ``df``, so they no longer depend on which ``df`` binding is current.
        methods_to_fields = [
            (lambda s: s.id.mean(), self.expr.id.mean()),
            (lambda s: len(s), self.expr.count()),
            (lambda s: s.id.var(ddof=0), self.expr.id.var(ddof=0)),
            (lambda s: s.id.std(ddof=0), self.expr.id.std(ddof=0)),
            (lambda s: s.id.median(), self.expr.id.median()),
            (lambda s: s.id.sum(), self.expr.id.sum()),
            (lambda s: s.id.min(), self.expr.id.min()),
            (lambda s: s.id.max(), self.expr.id.max()),
            (lambda s: s.isMale.min(), self.expr.isMale.min()),
            (lambda s: s.name.max(), self.expr.name.max()),
            (lambda s: s.birth.max(), self.expr.birth.max()),
            (lambda s: s.isMale.sum(), self.expr.isMale.sum()),
            (lambda s: s.isMale.any(), self.expr.isMale.any()),
            (lambda s: s.isMale.all(), self.expr.isMale.all()),
            (lambda s: s.name.nunique(), self.expr.name.nunique()),
        ]

        fields = [it[1].rename('f'+str(i)) for i, it in enumerate(methods_to_fields)]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        # A projection made only of reductions yields exactly one row.
        self.assertEqual(1, len(result))
        row = result[0]

        for i, (method, _) in enumerate(methods_to_fields):
            first = method(df)
            second = row[i]
            if isinstance(first, float):
                self.assertAlmostEqual(first, second)
            else:
                self.assertEqual(first, second)

    def testMapReduceByApplyDistributeSort(self):
        """Word count built from a row-wise ``apply`` and a grouped, sorted ``apply``.

        The mapper emits ``(word, 1)`` for every whitespace-separated word in
        ``name``. The reducer sums the counts per word, and the result is
        compared order-insensitively with hand-computed totals.
        """
        data = [
            ['name key', 4, 5.3, None, None, None],
            ['name', 2, 3.5, None, None, None],
            ['key', 4, 4.2, None, None, None],
            ['name', 3, 2.2, None, None, None],
            ['key name', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        def mapper(row):
            # One output row per word of the first (``name``) column.
            for word in row[0].split():
                yield word, 1

        class reducer(object):
            # Stateful running-sum reducer. It only emits a word's total when
            # a different word arrives, so it relies on rows coming in grouped
            # and sorted by ``word`` (the groupby('word').sort('word') below).
            def __init__(self):
                self._curr = None
                self._cnt = 0

            def __call__(self, row):
                if self._curr is None:
                    self._curr = row.word
                elif self._curr != row.word:
                    # Word changed: flush the previous word's total, then reset.
                    yield (self._curr, self._cnt)
                    self._curr = row.word
                    self._cnt = 0
                self._cnt += row.count

            def close(self):
                # Flush the last word; presumably invoked by the framework after
                # the final row has been processed.
                if self._curr is not None:
                    yield (self._curr, self._cnt)

        expr = self.expr['name', ].apply(
            mapper, axis=1, names=['word', 'count'], types=['string', 'int'])
        expr = expr.groupby('word').sort('word').apply(
            reducer, names=['word', 'count'], types=['string', 'int'])

        res = self.engine.execute(expr)
        result = self._get_result(res)

        # Across the input names, "key" occurs 3 times and "name" 4 times.
        expected = [['key', 3], ['name', 4]]
        self.assertEqual(sorted(result), sorted(expected))

    def testMapReduce(self):
        """Word count through ``map_reduce`` with a function-style reducer and
        then a class-style reducer, both grouped on ``word``."""
        data = [
            ['name key', 4, 5.3, None, None, None],
            ['name', 2, 3.5, None, None, None],
            ['key', 4, 4.2, None, None, None],
            ['name', 3, 2.2, None, None, None],
            ['key name', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        @output(['word', 'cnt'], ['string', 'int'])
        def mapper(row):
            # Emit (word, 1) for each whitespace-separated word of ``name``.
            for word in row[0].split():
                yield word, 1

        @output(['word', 'cnt'], ['string', 'int'])
        def reducer(keys):
            # Called with the group key(s) and returns the per-row handler ``h``.
            # The one-element list lets ``h`` update the counter without
            # ``nonlocal``.
            cnt = [0, ]

            def h(row, done):
                cnt[0] += row[1]
                # ``done`` presumably flags the group's last row: emit the total.
                if done:
                    yield keys[0], cnt[0]

            return h

        expr = self.expr['name', ].map_reduce(mapper, reducer, group='word')

        res = self.engine.execute(expr)
        result = self._get_result(res)

        expected = [['key', 3], ['name', 4]]
        self.assertEqual(sorted(result), sorted(expected))

        @output(['word', 'cnt'], ['string', 'int'])
        class reducer2(object):
            # Class-style equivalent of ``reducer``: constructed with the group
            # key(s) (accepted but unused here) and then called once per row.
            def __init__(self, keys):
                self.cnt = 0

            def __call__(self, row, done):
                self.cnt += row.cnt
                if done:
                    yield row.word, self.cnt

        expr = self.expr['name', ].map_reduce(mapper, reducer2, group='word')

        res = self.engine.execute(expr)
        result = self._get_result(res)

        expected = [['key', 3], ['name', 4]]
        self.assertEqual(sorted(result), sorted(expected))

    def testDistributeSort(self):
        """Count rows per ``name`` with a stateful reducer applied after
        grouping and sorting by ``name``."""
        data = [
            ['name', 4, 5.3, None, None, None],
            ['name', 2, 3.5, None, None, None],
            ['key', 4, 4.2, None, None, None],
            ['name', 3, 2.2, None, None, None],
            ['key', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        @output_names('name', 'id')
        @output_types('string', 'int')
        class reducer(object):
            # Running row counter. A name's count is only emitted when a
            # different name arrives, so rows must come in grouped and sorted
            # by ``name`` (the groupby/sort below).
            def __init__(self):
                self._curr = None
                self._cnt = 0

            def __call__(self, row):
                if self._curr is None:
                    self._curr = row.name
                elif self._curr != row.name:
                    # Name changed: flush the previous name's count, then reset.
                    yield (self._curr, self._cnt)
                    self._curr = row.name
                    self._cnt = 0
                self._cnt += 1

            def close(self):
                # Flush the final name; presumably called after the last row.
                if self._curr is not None:
                    yield (self._curr, self._cnt)

        expr = self.expr['name', ].groupby('name').sort('name').apply(reducer)

        res = self.engine.execute(expr)
        result = self._get_result(res)

        # The input has 2 "key" rows and 3 "name" rows.
        expected = [['key', 2], ['name', 3]]
        self.assertEqual(sorted(expected), sorted(result))

    def testJoin(self):
        """Join the source table with a second table, first without an explicit
        ``on`` (matching on ``name``, the only shared column) and then on
        ``name`` plus the ``id``/``id2`` pair."""
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])
        table_name = 'pyodps_test_engine_table2'
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)

        # Everything after creating table2 runs under ``try`` so the table is
        # dropped even if data generation or the upload fails.
        try:
            expr2 = CollectionExpr(_source_data=table2, _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)

            data2 = [
                ['name1', 4, -1],
                ['name2', 1, -2]
            ]

            self.odps.write_table(table2, 0, [table2.new_record(values=d) for d in data2])

            expr = self.expr.join(expr2)['name', 'id2']

            res = self.engine.execute(expr)
            result = self._get_result(res)

            self.assertEqual(len(result), 5)
            expected = [
                [to_str('name1'), 4],
                [to_str('name2'), 1]
            ]
            self.assertTrue(all(it in expected for it in result))

            expr = self.expr.join(expr2, on=['name', ('id', 'id2')])[self.expr.name, expr2.id2]
            res = self.engine.execute(expr)
            result = self._get_result(res)
            self.assertEqual(len(result), 2)
            expected = [to_str('name1'), 4]
            self.assertTrue(all(it == expected for it in result))

        finally:
            table2.drop()

    def testUnion(self):
        """Union distinct (name, id) pairs with (id2 as id, name) rows of a
        second table and compare the sorted rows."""
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])
        table_name = 'pyodps_test_engine_table2'
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)

        # Everything after creating table2 runs under ``try`` so the table is
        # dropped even if data generation or the upload fails.
        try:
            expr2 = CollectionExpr(_source_data=table2, _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)

            data2 = [
                ['name3', 5, -1],
                ['name4', 6, -2]
            ]

            self.odps.write_table(table2, 0, [table2.new_record(values=d) for d in data2])

            expr = self.expr['name', 'id'].distinct().union(expr2[expr2.id2.rename('id'), 'name'])

            res = self.engine.execute(expr)
            result = self._get_result(res)

            expected = [
                ['name1', 4],
                ['name1', 3],
                ['name2', 2],
                ['name3', 5],
                ['name4', 6]
            ]

            result = sorted(result)
            expected = sorted(expected)

            self.assertEqual(len(result), len(expected))
            for res_row, exp_row in zip(result, expected):
                self.assertEqual([to_str(t) for t in res_row],
                                 [to_str(t) for t in exp_row])

        finally:
            table2.drop()

    def testPersist(self):
        """Persist the source expression three ways: into a new table, into a
        partition of a pre-created partitioned table, and into a table
        partitioned by the ``name`` column."""
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        table_name = 'pyodps_test_engine_persist_table'

        def cleanup():
            self.odps.delete_table(table_name, if_exists=True)

        def run(e):
            return self._get_result(self.engine.execute(e))

        # 1. Plain, unpartitioned target table.
        try:
            rows = run(self.engine.persist(self.expr, table_name))
            self.assertEqual(5, len(rows))
            self.assertEqual(data, rows)
        finally:
            cleanup()

        # 2. One partition of an existing table; the trailing ``ds`` partition
        #    column is stripped before comparing.
        try:
            schema = Schema.from_lists(self.schema.names, self.schema.types, ['ds'], ['string'])
            self.odps.create_table(table_name, schema)
            persisted = self.engine.persist(self.expr, table_name,
                                            partition='ds=today', create_partition=True)
            rows = run(persisted)
            self.assertEqual(5, len(rows))
            self.assertEqual(data, [row[:-1] for row in rows])
        finally:
            cleanup()

        # 3. Table partitioned by the ``name`` data column.
        try:
            self.engine.persist(self.expr, table_name, partitions=['name'])

            t = self.odps.get_table(table_name)
            self.assertEqual(2, len(list(t.partitions)))
            for spec, count in (('name=name1', 4), ('name=name2', 1)):
                with t.open_reader(partition=spec, reopen=True) as reader:
                    self.assertEqual(count, reader.count)
        finally:
            cleanup()

    def teardown(self):
        """Drop the source table used by this test case."""
        source_table = self.table
        source_table.drop()
Ejemplo n.º 33
0
class Test(TestBase):
    def setup(self):
        """Create a fresh source table, its collection expression, and the ODPS engine under test."""
        def datatypes(*type_names):
            # Named ``type_names`` (not ``types``) so the module-level ``types``
            # name used elsewhere in this file is not shadowed.
            return [validate_data_type(t) for t in type_names]

        schema = Schema.from_lists(
            ['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
            datatypes('string', 'int64', 'float64', 'boolean', 'decimal',
                      'datetime'))
        self.schema = df_schema_to_odps_schema(schema)
        table_name = 'pyodps_test_engine_table'
        self.odps.delete_table(table_name, if_exists=True)
        # Reuse ``table_name`` so the deleted and created tables cannot drift apart.
        self.table = self.odps.create_table(name=table_name,
                                            schema=self.schema)
        self.expr = CollectionExpr(_source_data=self.table, _schema=schema)

        self.engine = ODPSEngine(self.odps)

        class FakeBar(object):
            # No-op progress bar passed to ``_handle_cases`` by the tests.
            def update(self, *args, **kwargs):
                pass

        self.faked_bar = FakeBar()

    def _gen_data(self,
                  rows=None,
                  data=None,
                  nullable_field=None,
                  value_range=None):
        """Write rows into ``self.table`` and return them.

        :param rows: number of random records to generate when ``data`` is None
        :param data: explicit records to write; when given, nothing is
            generated and ``nullable_field`` is ignored
        :param nullable_field: column set to None on every even-indexed
            generated record
        :param value_range: forwarded to the bigint generator only
        :returns: the list of records that were written
        :raises ValueError: if neither ``rows`` nor ``data`` is given
        """
        if data is None:
            if rows is None:
                raise ValueError('either rows or data must be provided')
            data = []
            for _ in range(rows):
                record = []
                for t in self.schema.types:
                    # One ``_gen_random_<type name>`` generator per column type.
                    method = getattr(self, '_gen_random_%s' % t.name)
                    if t.name == 'bigint':
                        record.append(method(value_range=value_range))
                    else:
                        record.append(method())
                data.append(record)

            if nullable_field is not None:
                j = self.schema._name_indexes[nullable_field]
                # The slice holds the same row objects, so this mutates ``data``.
                for record in data[::2]:
                    record[j] = None

        self.odps.write_table(self.table, 0,
                              [self.table.new_record(values=d) for d in data])
        return data

    def testTunnelCases(self):
        """Check which expressions ``_handle_cases`` answers directly.

        Counts and projections on the plain source table are handled, as are
        counts/projections filtered to one partition of a partitioned table.
        A whole-table count on a partitioned table is not handled
        (``_handle_cases`` returns None).
        """
        data = self._gen_data(10, value_range=(-1000, 1000))

        expr = self.expr.count()
        res = self.engine._handle_cases(expr, self.faked_bar)
        result = self._get_result(res)
        self.assertEqual(10, result)

        expr = self.expr.name.count()
        res = self.engine._handle_cases(expr, self.faked_bar)
        result = self._get_result(res)
        self.assertEqual(10, result)

        res = self.engine._handle_cases(self.expr, self.faked_bar)
        result = self._get_result(res)
        self.assertEqual(data, result)

        expr = self.expr['name', self.expr.id.rename('new_id')]
        res = self.engine._handle_cases(expr, self.faked_bar)
        result = self._get_result(res)
        self.assertEqual([it[:2] for it in data], result)

        table_name = 'pyodps_test_engine_partitioned'
        self.odps.delete_table(table_name, if_exists=True)

        # ``persist`` runs inside ``try`` so a table left behind by a failed
        # persist is still cleaned up.
        try:
            df = self.engine.persist(self.expr, table_name, partitions=['name'])

            expr = df.count()
            res = self.engine._handle_cases(expr, self.faked_bar)
            self.assertIsNone(res)

            expr = df[df.name == data[0][0]]['fid', 'id'].count()
            expr = self.engine._pre_process(expr)
            res = self.engine._handle_cases(expr, self.faked_bar)
            self.assertGreater(res, 0)

            expr = df[df.name == data[0][0]]['fid', 'id']
            res = self.engine._handle_cases(expr, self.faked_bar)
            self.assertGreater(len(res), 0)
        finally:
            self.odps.delete_table(table_name, if_exists=True)

        try:
            df = self.engine.persist(self.expr,
                                     table_name,
                                     partitions=['name', 'id'])

            expr = df.count()
            res = self.engine._handle_cases(expr, self.faked_bar)
            self.assertIsNone(res)

            expr = df[(df.name == data[0][0])
                      & (df.id == data[0][1])]['fid', 'ismale'].count()
            expr = self.engine._pre_process(expr)
            res = self.engine._handle_cases(expr, self.faked_bar)
            self.assertGreater(res, 0)

            expr = df[(df.name == data[0][0])
                      & (df.id == data[0][1])]['fid', 'ismale']
            res = self.engine._handle_cases(expr, self.faked_bar)
            self.assertGreater(len(res), 0)
        finally:
            self.odps.delete_table(table_name, if_exists=True)

    def testAsync(self):
        """Run a sum asynchronously, check the instance has not already
        terminated, then wait and compare the fetched value."""
        data = self._gen_data(10, value_range=(-1000, 1000))

        expr = self.expr.id.sum()

        # ``async`` is a reserved keyword since Python 3.7, so ``async=True``
        # is a SyntaxError there. Passing it through ``**`` sends the same
        # keyword argument and still parses on every Python version.
        res = self.engine.execute(expr, **{'async': True})
        self.assertNotEqual(res.instance.status, Instance.Status.TERMINATED)
        res.wait()

        self.assertEqual(sum(it[1] for it in data), res.fetch())

    def testBase(self):
        """Filter plus projection, constant/derived columns, and sort+limit
        executed both with and without the tunnel."""
        data = self._gen_data(10, value_range=(-1000, 1000))

        def values_of(res):
            return self._get_result(res.values)

        filtered = self.expr[self.expr.id < 10]['name', lambda x: x.id]
        result = values_of(self.engine.execute(filtered))
        self.assertEqual(sum(1 for row in data if row[1] < 10), len(result))
        if len(result) > 0:
            self.assertEqual(2, len(result[0]))

        with_const = self.expr[Scalar(3).rename('const'), self.expr.id,
                               (self.expr.id + 1).rename('id2')]
        res = self.engine.execute(with_const)
        result = values_of(res)
        self.assertEqual([c.name for c in res.columns], ['const', 'id', 'id2'])
        self.assertTrue(all(row[0] == 3 for row in result))
        self.assertEqual(len(data), len(result))
        self.assertEqual([row[1] + 1 for row in data], [row[2] for row in result])

        lowest_five = sorted(data, key=lambda row: row[1])[:5]
        # The second pass tests the path that does not use the tunnel.
        for execute_kwargs in ({}, {'use_tunnel': False}):
            top = self.expr.sort('id')[:5]
            result = values_of(self.engine.execute(top, **execute_kwargs))
            self.assertEqual(lowest_five, result)

    def testElement(self):
        """Element-wise null checks, fillna, isin/notin, between, switch and cut."""
        data = self._gen_data(5, nullable_field='name')

        fields = [
            self.expr.name.isnull().rename('name1'),
            self.expr.name.notnull().rename('name2'),
            self.expr.name.fillna('test').rename('name3'),
            self.expr.id.isin([1, 2, 3]).rename('id1'),
            self.expr.id.isin(self.expr.fid.astype('int')).rename('id2'),
            self.expr.id.notin([1, 2, 3]).rename('id3'),
            self.expr.id.notin(self.expr.fid.astype('int')).rename('id4'),
            self.expr.id.between(self.expr.fid, 3).rename('id5'),
            self.expr.name.fillna('test').switch(
                'test',
                'test' + self.expr.name.fillna('test'),
                'test2',
                'test2' + self.expr.name.fillna('test'),
                default=self.expr.name).rename('name4'),
            self.expr.id.cut([100, 200, 300],
                             labels=['xsmall', 'small', 'large', 'xlarge'],
                             include_under=True,
                             include_over=True).rename('id6')
        ]

        result = self._get_result(self.engine.execute(self.expr[fields]))

        def column(k):
            return [row[k] for row in result]

        names = [row[0] for row in data]
        ids = [row[1] for row in data]

        self.assertEqual(len(data), len(result))

        self.assertEqual(sum(1 for n in names if n is None),
                         len([v for v in column(0) if v]))
        self.assertEqual(sum(1 for n in names if n is not None),
                         len([v for v in column(1) if v]))

        self.assertEqual([n if n is not None else 'test' for n in names],
                         column(2))

        self.assertEqual([i in (1, 2, 3) for i in ids], column(3))

        fids = [int(row[2]) for row in data]
        self.assertEqual([i in fids for i in ids], column(4))

        self.assertEqual([i not in (1, 2, 3) for i in ids], column(5))
        self.assertEqual([i not in fids for i in ids], column(6))

        self.assertEqual([row[2] <= row[1] <= 3 for row in data], column(7))

        self.assertEqual(
            [to_str('testtest' if n is None else n) for n in names],
            [to_str(v) for v in column(8)])

        def bucket(val):
            # Mirrors cut([100, 200, 300]) with under/over buckets enabled.
            if val <= 100:
                return 'xsmall'
            if val <= 200:
                return 'small'
            if val <= 300:
                return 'large'
            return 'xlarge'

        self.assertEqual([to_str(bucket(i)) for i in ids],
                         [to_str(v) for v in column(9)])

    def testArithmetic(self):
        """Compare arithmetic, unary and datetime operators with Python's."""
        data = self._gen_data(5, value_range=(-1000, 1000))

        fields = [
            (self.expr.id + 1).rename('id1'),
            (self.expr.fid - 1).rename('fid1'),
            (self.expr.scale * 2).rename('scale1'),
            (self.expr.scale + self.expr.id).rename('scale2'),
            (self.expr.id / 2).rename('id2'),
            (self.expr.id**-2).rename('id3'),
            abs(self.expr.id).rename('id4'),
            (~self.expr.id).rename('id5'),
            (-self.expr.fid).rename('fid2'),
            (~self.expr.isMale).rename('isMale1'),
            # Column 10 (isMale2) is computed but not asserted below.
            (-self.expr.isMale).rename('isMale2'),
            (self.expr.id // 2).rename('id6'),
            # Rename the sum itself; previously ``.rename`` bound to ``day(1)``.
            (self.expr.birth + day(1)).rename('birth1'),
            (self.expr.birth -
             (self.expr.birth - millisecond(10))).rename('birth2'),
        ]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        def assert_all_almost_equal(expected, actual):
            # unittest's assertAlmostEqual subtracts its operands, which fails
            # with TypeError for lists, so compare element-wise instead.
            self.assertEqual(len(expected), len(actual))
            for exp_val, act_val in zip(expected, actual):
                self.assertAlmostEqual(exp_val, act_val)

        self.assertEqual(len(data), len(result))

        self.assertEqual([it[1] + 1 for it in data], [it[0] for it in result])

        assert_all_almost_equal([it[2] - 1 for it in data],
                                [it[1] for it in result])

        self.assertEqual([it[4] * 2 for it in data], [it[2] for it in result])

        self.assertEqual([it[4] + it[1] for it in data],
                         [it[3] for it in result])

        assert_all_almost_equal([float(it[1]) / 2 for it in data],
                                [it[4] for it in result])

        # NOTE(review): if the generated id range can include 0, ``0 ** -2``
        # raises ZeroDivisionError here; confirm _gen_random_bigint's bounds.
        self.assertEqual([int(it[1]**-2) for it in data],
                         [it[5] for it in result])

        self.assertEqual([abs(it[1]) for it in data], [it[6] for it in result])

        self.assertEqual([~it[1] for it in data], [it[7] for it in result])

        assert_all_almost_equal([-it[2] for it in data],
                                [it[8] for it in result])

        self.assertEqual([not it[3] for it in data], [it[9] for it in result])

        self.assertEqual([it[1] // 2 for it in data],
                         [it[11] for it in result])

        self.assertEqual([it[5] + timedelta(days=1) for it in data],
                         [it[12] for it in result])

        # birth - (birth - 10ms) is expected to come back as 10 (milliseconds).
        self.assertEqual([10] * len(data), [it[13] for it in result])

    def testMath(self):
        """Compare unary math functions on ``id`` with their numpy ufuncs;
        positions where both sides are NaN are skipped."""
        data = self._gen_data(5, value_range=(1, 90))

        import numpy as np

        pairs = [
            (np.sin, self.expr.id.sin()),
            (np.cos, self.expr.id.cos()),
            (np.tan, self.expr.id.tan()),
            (np.sinh, self.expr.id.sinh()),
            (np.cosh, self.expr.id.cosh()),
            (np.tanh, self.expr.id.tanh()),
            (np.log, self.expr.id.log()),
            (np.log2, self.expr.id.log2()),
            (np.log10, self.expr.id.log10()),
            (np.log1p, self.expr.id.log1p()),
            (np.exp, self.expr.id.exp()),
            (np.expm1, self.expr.id.expm1()),
            (np.arccosh, self.expr.id.arccosh()),
            (np.arcsinh, self.expr.id.arcsinh()),
            (np.arctanh, self.expr.id.arctanh()),
            (np.arctan, self.expr.id.arctan()),
            (np.sqrt, self.expr.id.sqrt()),
            (np.abs, self.expr.id.abs()),
            (np.ceil, self.expr.id.ceil()),
            (np.floor, self.expr.id.floor()),
            (np.trunc, self.expr.id.trunc()),
        ]

        fields = [field.rename('id%d' % i) for i, (_, field) in enumerate(pairs)]

        result = self._get_result(self.engine.execute(self.expr[fields]))

        ids = [row[1] for row in data]
        for col, (np_func, _) in enumerate(pairs):
            wanted = [np_func(v) for v in ids]
            got = [row[col] for row in result]
            self.assertEqual(len(wanted), len(got))
            for w, g in zip(wanted, got):
                if not (np.isnan(w) and np.isnan(g)):
                    self.assertAlmostEqual(w, g)

    def testString(self):
        """String methods on ``name`` must match the equivalent Python ``str`` operations."""
        data = self._gen_data(5)
        # The first character of the first generated name is used as the search
        # substring for the contains/count/find/replace style checks.
        needle = data[0][0]

        methods_to_fields = [
            (lambda s: s.capitalize(), self.expr.name.capitalize()),
            (lambda s: needle in s,
             self.expr.name.contains(needle, regex=False)),
            (lambda s: s.count(needle), self.expr.name.count(needle)),
            (lambda s: s.endswith(needle), self.expr.name.endswith(needle)),
            (lambda s: s.startswith(needle), self.expr.name.startswith(needle)),
            (lambda s: s.find(needle), self.expr.name.find(needle)),
            (lambda s: s.rfind(needle), self.expr.name.rfind(needle)),
            (lambda s: s.replace(needle, 'test'),
             self.expr.name.replace(needle, 'test')),
            (lambda s: s[0], self.expr.name.get(0)),
            (lambda s: len(s), self.expr.name.len()),
            (lambda s: s.ljust(10), self.expr.name.ljust(10)),
            (lambda s: s.ljust(20, '*'), self.expr.name.ljust(20, fillchar='*')),
            (lambda s: s.rjust(10), self.expr.name.rjust(10)),
            (lambda s: s.rjust(20, '*'), self.expr.name.rjust(20, fillchar='*')),
            (lambda s: s * 4, self.expr.name.repeat(4)),
            (lambda s: s[2:10:2], self.expr.name.slice(2, 10, 2)),
            (lambda s: s[-5:-1], self.expr.name.slice(-5, -1)),
            (lambda s: s.title(), self.expr.name.title()),
            (lambda s: s.rjust(20, '0'), self.expr.name.zfill(20)),
            (lambda s: s.isalnum(), self.expr.name.isalnum()),
            (lambda s: s.isalpha(), self.expr.name.isalpha()),
            (lambda s: s.isdigit(), self.expr.name.isdigit()),
            (lambda s: s.isspace(), self.expr.name.isspace()),
            (lambda s: s.isupper(), self.expr.name.isupper()),
            (lambda s: s.istitle(), self.expr.name.istitle()),
            (lambda s: to_str(s).isnumeric(), self.expr.name.isnumeric()),
            (lambda s: to_str(s).isdecimal(), self.expr.name.isdecimal()),
        ]

        fields = [field.rename('id' + str(idx))
                  for idx, (_, field) in enumerate(methods_to_fields)]

        res = self.engine.execute(self.expr[fields])
        result = self._get_result(res)

        for col, (py_func, _) in enumerate(methods_to_fields):
            expected_col = [py_func(row[0]) for row in data]
            actual_col = [row[col] for row in result]
            self.assertEqual(expected_col, actual_col)

    def testApply(self):
        """Row-wise ``apply`` with a tuple-returning UDF and a generator UDF."""
        data = self._gen_data(5)

        def my_func(row):
            return row.name,

        single = self.expr['name', 'id'].apply(my_func, axis=1, names='name')
        result = self._get_result(self.engine.execute(single))
        self.assertEqual([row[0] for row in result], [row[0] for row in data])

        def my_func2(row):
            yield len(row.name)
            yield row.id

        multi = self.expr['name', 'id'].apply(my_func2, axis=1,
                                              names='cnt', types='int')
        result = self._get_result(self.engine.execute(multi))

        # my_func2 emits two output rows per input row: len(name), then id.
        expected = []
        for row in data:
            expected.extend([len(row[0]), row[1]])
        self.assertEqual([row[0] for row in result], expected)

    def testDatetime(self):
        """Datetime field accessors on ``birth`` must agree with pandas' ``.dt`` accessors."""
        data = self._gen_data(5)

        import pandas as pd

        def conv(v):
            # Engine results may come back as pandas Timestamps; reduce them to a
            # ``datetime.date`` so they compare against ``.dt.date`` values.
            # ``Timestamp.to_datetime`` was removed in pandas 1.0;
            # ``to_pydatetime`` is the supported equivalent.
            if isinstance(v, pd.Timestamp):
                return v.to_pydatetime().date()
            return v

        # NOTE(review): ``Series.dt.weekofyear`` was removed in pandas 2.0
        # (replacement: ``.dt.isocalendar().week``); confirm the pinned pandas version.
        methods_to_fields = [
            (lambda s: list(s.birth.dt.year.values), self.expr.birth.year),
            (lambda s: list(s.birth.dt.month.values), self.expr.birth.month),
            (lambda s: list(s.birth.dt.day.values), self.expr.birth.day),
            (lambda s: list(s.birth.dt.hour.values), self.expr.birth.hour),
            (lambda s: list(s.birth.dt.minute.values), self.expr.birth.minute),
            (lambda s: list(s.birth.dt.second.values), self.expr.birth.second),
            (lambda s: list(s.birth.dt.weekofyear.values),
             self.expr.birth.weekofyear),
            (lambda s: list(s.birth.dt.dayofweek.values),
             self.expr.birth.dayofweek),
            (lambda s: list(s.birth.dt.weekday.values),
             self.expr.birth.weekday),
            (lambda s: list(s.birth.dt.date.values), self.expr.birth.date),
            (lambda s: list(s.birth.dt.strftime('%Y%d')),
             self.expr.birth.strftime('%Y%d'))
        ]

        fields = [
            it[1].rename('birth' + str(i))
            for i, it in enumerate(methods_to_fields)
        ]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        df = pd.DataFrame(data, columns=self.schema.names)

        for i, (method, _) in enumerate(methods_to_fields):
            first = method(df)
            second = [conv(row[i]) for row in result]
            self.assertEqual(first, second)

    def testSortDistinct(self):
        """Sort by name ascending and id descending, then take distinct (name, id + 1)."""
        self._gen_data(data=[
            ['name1', 4, None, None, None, None],
            ['name2', 2, None, None, None, None],
            ['name1', 4, None, None, None, None],
            ['name1', 3, None, None, None, None],
        ])

        ordered = self.expr.sort(['name', -self.expr.id])
        expr = ordered.distinct(['name', lambda x: x.id + 1])[:50]

        result = self._get_result(self.engine.execute(expr))

        self.assertEqual(len(result), 3)
        self.assertEqual([['name1', 5], ['name1', 4], ['name2', 3]], result)

    def testGroupbyAggregation(self):
        """Grouped aggregation with a group filter, row_number, value_counts, topk, count and nunique."""
        self._gen_data(data=[
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ])

        def run(e):
            # Execute an expression and fetch its rows as plain lists.
            return self._get_result(self.engine.execute(e))

        # Keep only (name, id) groups whose min(fid) * 2 < 8, then aggregate.
        kept = self.expr.groupby(['name', 'id'])[lambda x: x.fid.min() * 2 < 8]
        aggregated = kept.agg(self.expr.fid.max() + 1, new_id=self.expr.id.sum())
        rows = sorted(run(aggregated), key=lambda r: r[0])
        self.assertEqual([['name1', 3, 5.1, 6], ['name2', 2, 4.5, 2]], rows)

        # row_number within each name, ordered by id asc then fid desc.
        numbering = self.expr.groupby('name').sort(['id',
                                                    -self.expr.fid]).row_number()
        rows = sorted(run(self.expr['name', 'id', 'fid', numbering]),
                      key=lambda r: (r[0], r[1], -r[2]))
        self.assertEqual([
            ['name1', 3, 4.1, 1],
            ['name1', 3, 2.2, 2],
            ['name1', 4, 5.3, 3],
            ['name1', 4, 4.2, 4],
            ['name2', 2, 3.5, 1],
        ], rows)

        name_counts = [['name1', 4], ['name2', 1]]
        self.assertEqual(name_counts, run(self.expr.name.value_counts()[:25]))
        self.assertEqual(name_counts, run(self.expr.name.topk(25)))
        self.assertEqual([r[1:] for r in name_counts],
                         run(self.expr.groupby('name').count()))

        distinct_ids = [['name1', 2], ['name2', 1]]
        self.assertEqual([r[1:] for r in distinct_ids],
                         run(self.expr.groupby('name').id.nunique()))

        filtered = self.expr[self.expr['id'] > 2].name.value_counts()[:25]
        self.assertEqual([['name1', 4]], run(filtered))

    def testJoinGroupby(self):
        """Join on ``name``, group by ``id`` summing ``fid``, and compare against pandas.

        The auxiliary table is dropped in ``finally`` so a failing assertion or
        data-load error does not leave it behind.
        """
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])

        table_name = 'pyodps_test_engine_table2'
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)
        try:
            expr2 = CollectionExpr(_source_data=table2,
                                   _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)

            data2 = [['name1', 4, -1], ['name2', 1, -2]]

            self.odps.write_table(table2, 0,
                                  [table2.new_record(values=d) for d in data2])

            expr = self.expr.join(expr2, on='name')[self.expr]
            expr = expr.groupby('id').agg(expr.fid.sum())

            res = self.engine.execute(expr)
            result = self._get_result(res)

            import pandas as pd
            expected = pd.DataFrame(data, columns=self.expr.schema.names) \
                .groupby('id').agg({'fid': 'sum'}).reset_index().values.tolist()

            # zip() would silently drop unmatched rows, so compare sizes first.
            self.assertEqual(len(expected), len(result))
            for exp_row, res_row in zip(sorted(expected, key=lambda r: r[0]),
                                        sorted(result, key=lambda r: r[0])):
                self.assertAlmostEqual(exp_row[0], res_row[0])
                self.assertAlmostEqual(exp_row[1], res_row[1])
        finally:
            table2.drop()

    def testFilterGroupby(self):
        """Filtering on an aggregated column keeps only groups whose max id exceeds 3."""
        self._gen_data(data=[
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ])

        per_name = self.expr.groupby(['name']).agg(id=self.expr.id.max())
        expr = per_name[lambda x: x.id > 3]

        result = self._get_result(self.engine.execute(expr))

        self.assertEqual(len(result), 1)
        self.assertEqual([['name1', 4]], result)

    def testWindowRewrite(self):
        """Chained filters/projections that use window aggregates must evaluate correctly."""
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        near_mean = self.expr[self.expr.id - self.expr.id.mean() < 10]
        minus_max = near_mean[[lambda x: x.id - x.id.max()]]
        minus_min = minus_max[[lambda x: x.id - x.id.min()]]
        expr = minus_min[lambda x: x.id - x.id.std() > 0]

        result = self._get_result(self.engine.execute(expr))

        # With ids in [2, 4] the first "< 10 from the mean" filter keeps every row,
        # so the pandas reference starts from the raw id column.
        import pandas as pd
        ids = pd.DataFrame(data, columns=self.schema.names).id
        shifted = ids - ids.max()
        shifted = shifted - shifted.min()
        expected = list(shifted[shifted - shifted.std() > 0])

        self.assertEqual(expected, [row[0] for row in result])

    def testReduction(self):
        """Whole-collection reductions must agree with pandas on the same data.

        Each reference callable receives the pandas frame as its argument, so
        there is no need for a closed-over frame or a second, redundant one.
        """
        data = self._gen_data(rows=5, value_range=(-100, 100))

        import pandas as pd

        methods_to_fields = [
            (lambda s: s.id.mean(), self.expr.id.mean()),
            (lambda s: len(s), self.expr.count()),
            (lambda s: s.id.var(ddof=0), self.expr.id.var(ddof=0)),
            (lambda s: s.id.std(ddof=0), self.expr.id.std(ddof=0)),
            (lambda s: s.id.median(), self.expr.id.median()),
            (lambda s: s.id.sum(), self.expr.id.sum()),
            (lambda s: s.id.min(), self.expr.id.min()),
            (lambda s: s.id.max(), self.expr.id.max()),
            (lambda s: s.isMale.min(), self.expr.isMale.min()),
            (lambda s: s.name.max(), self.expr.name.max()),
            (lambda s: s.birth.max(), self.expr.birth.max()),
            (lambda s: s.isMale.sum(), self.expr.isMale.sum()),
            (lambda s: s.isMale.any(), self.expr.isMale.any()),
            (lambda s: s.isMale.all(), self.expr.isMale.all()),
            (lambda s: s.name.nunique(), self.expr.name.nunique()),
        ]

        fields = [
            it[1].rename('f' + str(i))
            for i, it in enumerate(methods_to_fields)
        ]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        df = pd.DataFrame(data, columns=self.schema.names)

        for i, (method, _) in enumerate(methods_to_fields):
            first = method(df)
            # A full reduction yields a single row; column i holds reduction i.
            second = result[0][i]
            if isinstance(first, float):
                self.assertAlmostEqual(first, second)
            else:
                self.assertEqual(first, second)

    def testMapReduceByApplyDistributeSort(self):
        """Word count built from a row-wise ``apply`` mapper and a grouped, sorted ``apply`` reducer."""
        data = [
            ['name key', 4, 5.3, None, None, None],
            ['name', 2, 3.5, None, None, None],
            ['key', 4, 4.2, None, None, None],
            ['name', 3, 2.2, None, None, None],
            ['key name', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        # Mapper: emit (word, 1) for every whitespace-separated word of ``name``.
        def mapper(row):
            for word in row[0].split():
                yield word, 1

        # Reducer: fed by groupby('word').sort('word') below; it accumulates counts
        # and emits (word, count) each time the word changes. The last pending word
        # is emitted from ``close`` (presumably called by the framework after the
        # final row — confirm).
        class reducer(object):
            def __init__(self):
                self._curr = None
                self._cnt = 0

            def __call__(self, row):
                if self._curr is None:
                    self._curr = row.word
                elif self._curr != row.word:
                    # Word boundary: flush the previous word before resetting.
                    yield (self._curr, self._cnt)
                    self._curr = row.word
                    self._cnt = 0
                self._cnt += row.count

            def close(self):
                # Flush the final word, if any row was seen at all.
                if self._curr is not None:
                    yield (self._curr, self._cnt)

        expr = self.expr['name', ].apply(mapper,
                                         axis=1,
                                         names=['word', 'count'],
                                         types=['string', 'int'])
        expr = expr.groupby('word').sort('word').apply(reducer,
                                                       names=['word', 'count'],
                                                       types=['string', 'int'])

        res = self.engine.execute(expr)
        result = self._get_result(res)

        # Across the five names, "key" occurs 3 times and "name" 4 times.
        expected = [['key', 3], ['name', 4]]
        self.assertEqual(sorted(result), sorted(expected))

    def testMapReduce(self):
        """Word count via ``map_reduce`` with a closure-based reducer and a class-based reducer."""
        data = [
            ['name key', 4, 5.3, None, None, None],
            ['name', 2, 3.5, None, None, None],
            ['key', 4, 4.2, None, None, None],
            ['name', 3, 2.2, None, None, None],
            ['key name', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        # Mapper: emit (word, 1) for every whitespace-separated word of ``name``.
        @output(['word', 'cnt'], ['string', 'int'])
        def mapper(row):
            for word in row[0].split():
                yield word, 1

        # Reducer factory: called with the group's key values and returns a per-row
        # handler. The count lives in a one-element list so the nested ``h`` can
        # mutate it without ``nonlocal``.
        @output(['word', 'cnt'], ['string', 'int'])
        def reducer(keys):
            cnt = [
                0,
            ]

            def h(row, done):
                cnt[0] += row[1]
                # ``done`` presumably marks the group's last row — confirm.
                if done:
                    yield keys[0], cnt[0]

            return h

        expr = self.expr['name', ].map_reduce(mapper, reducer, group='word')

        res = self.engine.execute(expr)
        result = self._get_result(res)

        # Across the five names, "key" occurs 3 times and "name" 4 times.
        expected = [['key', 3], ['name', 4]]
        self.assertEqual(sorted(result), sorted(expected))

        # Same reducer written as a class: constructed per group, called per row.
        @output(['word', 'cnt'], ['string', 'int'])
        class reducer2(object):
            def __init__(self, keys):
                self.cnt = 0

            def __call__(self, row, done):
                self.cnt += row.cnt
                if done:
                    yield row.word, self.cnt

        expr = self.expr['name', ].map_reduce(mapper, reducer2, group='word')

        res = self.engine.execute(expr)
        result = self._get_result(res)

        expected = [['key', 3], ['name', 4]]
        self.assertEqual(sorted(result), sorted(expected))

    def testDistributeSort(self):
        """Count rows per ``name`` with a stateful reducer applied via groupby/sort/apply."""
        data = [
            ['name', 4, 5.3, None, None, None],
            ['name', 2, 3.5, None, None, None],
            ['key', 4, 4.2, None, None, None],
            ['name', 3, 2.2, None, None, None],
            ['key', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        # Reducer: rows come grouped and sorted by ``name``; count consecutive rows
        # with the same name and emit (name, count) on each change, flushing the
        # final name from ``close``.
        @output_names('name', 'id')
        @output_types('string', 'int')
        class reducer(object):
            def __init__(self):
                self._curr = None
                self._cnt = 0

            def __call__(self, row):
                if self._curr is None:
                    self._curr = row.name
                elif self._curr != row.name:
                    # Name boundary: flush the previous name before resetting.
                    yield (self._curr, self._cnt)
                    self._curr = row.name
                    self._cnt = 0
                self._cnt += 1

            def close(self):
                if self._curr is not None:
                    yield (self._curr, self._cnt)

        expr = self.expr['name', ].groupby('name').sort('name').apply(reducer)

        res = self.engine.execute(expr)
        result = self._get_result(res)

        # Two rows are named "key" and three are named "name".
        expected = [['key', 2], ['name', 3]]
        self.assertEqual(sorted(expected), sorted(result))

    def testJoin(self):
        """Inner join against a second table, with an implicit key and with explicit keys.

        The ``try`` starts immediately after ``table2`` is created, so a failure while
        loading data still drops it in ``finally``.
        """
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])
        table_name = 'pyodps_test_engine_table2'
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)
        try:
            expr2 = CollectionExpr(_source_data=table2,
                                   _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)

            data2 = [['name1', 4, -1], ['name2', 1, -2]]

            self.odps.write_table(table2, 0,
                                  [table2.new_record(values=d) for d in data2])

            # No ``on``: presumably joins on the only shared column, ``name``.
            expr = self.expr.join(expr2)['name', 'id2']

            res = self.engine.execute(expr)
            result = self._get_result(res)

            self.assertEqual(len(result), 5)
            expected = [[to_str('name1'), 4], [to_str('name2'), 1]]
            self.assertTrue(all(it in expected for it in result))

            # Explicit keys: name == name and id == id2; only the two
            # (name1, 4) rows match table2's (name1, 4) row.
            expr = self.expr.join(expr2, on=['name',
                                             ('id', 'id2')])[self.expr.name,
                                                             expr2.id2]
            res = self.engine.execute(expr)
            result = self._get_result(res)
            self.assertEqual(len(result), 2)
            expected = [to_str('name1'), 4]
            self.assertTrue(all(it == expected for it in result))

        finally:
            table2.drop()

    def testUnion(self):
        """Union distinct (name, id) pairs with the rows of a second table.

        The ``try`` starts immediately after ``table2`` is created, so a failure while
        loading data still drops it in ``finally``.
        """
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])
        table_name = 'pyodps_test_engine_table2'
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)
        try:
            expr2 = CollectionExpr(_source_data=table2,
                                   _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)

            data2 = [['name3', 5, -1], ['name4', 6, -2]]

            self.odps.write_table(table2, 0,
                                  [table2.new_record(values=d) for d in data2])

            # Columns are selected in a different order on each side; the union
            # output is expected as (name, id).
            expr = self.expr['name', 'id'].distinct().union(
                expr2[expr2.id2.rename('id'), 'name'])

            res = self.engine.execute(expr)
            result = self._get_result(res)

            expected = [['name1', 4], ['name1', 3], ['name2', 2], ['name3', 5],
                        ['name4', 6]]

            result = sorted(result)
            expected = sorted(expected)

            self.assertEqual(len(result), len(expected))
            for res_row, exp_row in zip(result, expected):
                self.assertEqual([to_str(t) for t in res_row],
                                 [to_str(t) for t in exp_row])

        finally:
            table2.drop()

    def testPersist(self):
        """Persist into a new table, into a pre-created partition, and with dynamic partitions."""
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        table_name = 'pyodps_test_engine_persist_table'

        def drop_target():
            self.odps.delete_table(table_name, if_exists=True)

        # 1) Plain persist creates the target table.
        try:
            persisted = self.engine.persist(self.expr, table_name)
            rows = self._get_result(self.engine.execute(persisted))
            self.assertEqual(len(rows), 5)
            self.assertEqual(data, rows)
        finally:
            drop_target()

        # 2) Persist into an explicit partition of a pre-created partitioned table.
        try:
            partitioned = Schema.from_lists(self.schema.names, self.schema.types,
                                            ['ds'], ['string'])
            self.odps.create_table(table_name, partitioned)
            persisted = self.engine.persist(self.expr, table_name,
                                            partition='ds=today',
                                            create_partition=True)
            rows = self._get_result(self.engine.execute(persisted))
            self.assertEqual(len(rows), 5)
            # The trailing ``ds`` partition column is dropped before comparing.
            self.assertEqual(data, [row[:-1] for row in rows])
        finally:
            drop_target()

        # 3) Dynamic partitioning by ``name``.
        try:
            self.engine.persist(self.expr, table_name, partitions=['name'])

            tbl = self.odps.get_table(table_name)
            self.assertEqual(2, len(list(tbl.partitions)))
            for part_spec, row_count in (('name=name1', 4), ('name=name2', 1)):
                with tbl.open_reader(partition=part_spec, reopen=True) as reader:
                    self.assertEqual(row_count, reader.count)
        finally:
            drop_target()

    def teardown(self):
        """Drop ``self.table`` after each test."""
        table = self.table
        table.drop()
class Test(TestBase):
    """Expression-construction tests for ``CollectionExpr`` methods.

    Nothing here is executed by an engine: each test builds an expression and
    inspects its type, schema, or internal sort/name attributes.
    """

    def setup(self):
        from odps.df.expr.tests.core import MockTable
        # One column per supported data type, named after the type.
        schema = Schema.from_lists(types._data_types.keys(), types._data_types.values())
        # Unsourced collection, plus one backed by a mocked ODPS table.
        self.expr = CollectionExpr(_source_data=None, _schema=schema)
        self.sourced_expr = CollectionExpr(_source_data=MockTable(client=self.odps.rest), _schema=schema)

    @staticmethod
    def _pandas_available():
        """Return True if pandas can be imported.

        Per the branches below, ``sample(frac=...)``, ``append_id`` and ``split`` on
        an unsourced collection produce a pandas-backed expression when pandas is
        installed and an XFlow ``DataFrame`` otherwise.
        """
        try:
            import pandas  # noqa: F401
        except ImportError:
            return False
        return True

    def testSort(self):
        sorted_expr = self.expr.sort(self.expr.int64)
        self.assertIsInstance(sorted_expr, CollectionExpr)
        self.assertEqual(sorted_expr._schema, self.expr._schema)
        self.assertSequenceEqual(sorted_expr._ascending, [True])

        sorted_expr = self.expr.sort_values([self.expr.float32, 'string'])
        self.assertIsInstance(sorted_expr, CollectionExpr)
        self.assertEqual(sorted_expr._schema, self.expr._schema)
        self.assertSequenceEqual(sorted_expr._ascending, [True] * 2)

        # A scalar ``ascending`` applies to every sort key.
        sorted_expr = self.expr.sort([self.expr.decimal, 'boolean', 'string'], ascending=False)
        self.assertIsInstance(sorted_expr, CollectionExpr)
        self.assertEqual(sorted_expr._schema, self.expr._schema)
        self.assertSequenceEqual(sorted_expr._ascending, [False] * 3)

        sorted_expr = self.expr.sort([self.expr.int8, 'datetime', 'float64'],
                                     ascending=[False, True, False])
        self.assertIsInstance(sorted_expr, CollectionExpr)
        self.assertEqual(sorted_expr._schema, self.expr._schema)
        self.assertSequenceEqual(sorted_expr._ascending, [False, True, False])

        # Negating a column marks that key as descending.
        sorted_expr = self.expr.sort([-self.expr.int8, 'datetime', 'float64'])
        self.assertIsInstance(sorted_expr, CollectionExpr)
        self.assertEqual(sorted_expr._schema, self.expr._schema)
        self.assertSequenceEqual(sorted_expr._ascending, [False, True, True])

    def testDistinct(self):
        distinct = self.expr.distinct()
        self.assertIsInstance(distinct, CollectionExpr)
        self.assertEqual(distinct._schema, self.expr._schema)

        distinct = self.expr.distinct(self.expr.string)
        self.assertIsInstance(distinct, CollectionExpr)
        self.assertEqual(distinct._schema, self.expr[[self.expr.string]]._schema)

        distinct = self.expr.distinct([self.expr.boolean, 'decimal'])
        self.assertIsInstance(distinct, CollectionExpr)
        self.assertEqual(distinct._schema, self.expr[['boolean', 'decimal']]._schema)

        # ``unique()`` cannot be mixed with ordinary columns in a projection.
        self.assertRaises(ExpressionError, lambda: self.expr['boolean', self.expr.string.unique()])

    def testMapReduce(self):
        @output(['id', 'name', 'rating'], ['int', 'string', 'int'])
        def mapper(row):
            yield row.int64, row.string, row.int32

        # Only the reducer's declared output schema matters here; it is never run.
        # NOTE(review): ``i[0]`` is never incremented, so the ``<= 1`` guard is
        # always true; harmless for these schema checks.
        @output(['name', 'rating'], ['string', 'int'])
        def reducer(_):
            i = [0]
            def h(row):
                if i[0] <= 1:
                    yield row.name, row.rating
            return h

        # The group key is always the first sort field, ascending, unless it is
        # listed explicitly in ``sort``.
        expr = self.expr.map_reduce(mapper, reducer, group='name',
                                    sort='rating', ascending=False)
        self.assertEqual(expr.schema.names, ['name', 'rating'])
        self.assertEqual(len(expr._sort_fields), 2)
        self.assertTrue(expr._sort_fields[0]._ascending)
        self.assertFalse(expr._sort_fields[1]._ascending)

        expr = self.expr.map_reduce(mapper, reducer, group='name',
                                    sort=['rating', 'id'], ascending=[False, True])

        self.assertEqual(expr.schema.names, ['name', 'rating'])
        self.assertEqual(len(expr._sort_fields), 3)
        self.assertTrue(expr._sort_fields[0]._ascending)
        self.assertFalse(expr._sort_fields[1]._ascending)
        self.assertTrue(expr._sort_fields[2]._ascending)

        expr = self.expr.map_reduce(mapper, reducer, group='name',
                                    sort=['rating', 'id'], ascending=False)

        self.assertEqual(expr.schema.names, ['name', 'rating'])
        self.assertEqual(len(expr._sort_fields), 3)
        self.assertTrue(expr._sort_fields[0]._ascending)
        self.assertFalse(expr._sort_fields[1]._ascending)
        self.assertFalse(expr._sort_fields[2]._ascending)

        expr = self.expr.map_reduce(mapper, reducer, group='name',
                                    sort=['name', 'rating', 'id'],
                                    ascending=[False, True, False])

        self.assertEqual(expr.schema.names, ['name', 'rating'])
        self.assertEqual(len(expr._sort_fields), 3)
        self.assertFalse(expr._sort_fields[0]._ascending)
        self.assertTrue(expr._sort_fields[1]._ascending)
        self.assertFalse(expr._sort_fields[2]._ascending)

    def testSample(self):
        self.assertIsInstance(self.expr.sample(100), SampledCollectionExpr)
        self.assertIsInstance(self.expr.sample(parts=10), SampledCollectionExpr)
        if self._pandas_available():
            # Otherwise: go for Pandas
            self.assertIsInstance(self.expr.sample(frac=0.5), SampledCollectionExpr)
        else:
            # No pandas: go for XFlow
            self.assertIsInstance(self.expr.sample(frac=0.5), DataFrame)
        self.assertIsInstance(self.sourced_expr.sample(frac=0.5), DataFrame)

        # Invalid or conflicting argument combinations.
        self.assertRaises(ExpressionError, lambda: self.expr.sample())
        self.assertRaises(ExpressionError, lambda: self.expr.sample(i=-1))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(n=100, frac=0.5))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(n=100, parts=10))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(frac=0.5, parts=10))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(n=100, frac=0.5, parts=10))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(frac=-1))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(frac=1.5))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(parts=10, i=-1))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(parts=10, i=10))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(parts=10, n=10))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(weights='weights', strata='strata'))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(frac='Yes:10', strata='strata'))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(frac=set(), strata='strata'))
        self.assertRaises(ExpressionError, lambda: self.expr.sample(n=set(), strata='strata'))

    def testPivot(self):
        from odps.df.expr.dynamic import DynamicMixin

        # Result columns depend on the data, so the schema is dynamic: only the
        # row keys are known, and unknown names resolve to dynamic fields.
        expr = self.expr.pivot('string', 'int8', 'float32')

        self.assertIn('string', expr._schema._name_indexes)
        self.assertEqual(len(expr._schema._name_indexes), 1)
        self.assertIn('non_exist', expr._schema)
        self.assertIsInstance(expr['non_exist'], DynamicMixin)

        expr = self.expr.pivot(
            ['string', 'int8'], 'int16', ['datetime', 'string'])

        self.assertIn('string', expr._schema._name_indexes)
        self.assertIn('int8', expr._schema._name_indexes)
        self.assertEqual(len(expr._schema._name_indexes), 2)
        self.assertIn('non_exist', expr._schema)
        self.assertIsInstance(expr['non_exist'], DynamicMixin)

        # The pivot column must be a single column, not a list.
        self.assertRaises(ValueError, lambda: self.expr.pivot(
            ['string', 'int8'], ['datetime', 'string'], 'int16'))

    def testPivotTable(self):
        from odps.df.expr.dynamic import DynamicMixin

        # Without ``columns`` the schema is static: row keys + <value>_<aggfunc>.
        expr = self.expr.pivot_table(values='int8', rows='float32')
        self.assertNotIsInstance(expr, DynamicMixin)
        self.assertEqual(expr.schema.names, ['float32', 'int8_mean'])

        expr = self.expr.pivot_table(values=('int16', 'int32'), rows=['float32', 'int8'])
        self.assertEqual(expr.schema.names, ['float32', 'int8', 'int16_mean', 'int32_mean'])

        expr = self.expr.pivot_table(values=('int16', 'int32'), rows=['string', 'boolean'],
                                     aggfunc=['mean', 'sum'])
        self.assertEqual(expr.schema.names, ['string', 'boolean', 'int16_mean', 'int32_mean',
                                             'int16_sum', 'int32_sum'])
        self.assertEqual(expr.schema.types, [types.string, types.boolean, types.float64, types.float64,
                                             types.int16, types.int32])

        # User-defined aggregator (buffer / __call__ / merge / getvalue protocol);
        # its declared output name becomes the column suffix.
        @output(['my_mean'], ['float'])
        class Aggregator(object):
            def buffer(self):
                return [0.0, 0]

            def __call__(self, buffer, val):
                buffer[0] += val
                buffer[1] += 1

            def merge(self, buffer, pbuffer):
                buffer[0] += pbuffer[0]
                buffer[1] += pbuffer[1]

            def getvalue(self, buffer):
                if buffer[1] == 0:
                    return 0.0
                return buffer[0] / buffer[1]

        expr = self.expr.pivot_table(values='int16', rows='string', aggfunc=Aggregator)
        self.assertEqual(expr.schema.names, ['string', 'int16_my_mean'])
        self.assertEqual(expr.schema.types, [types.string, types.float64])

        # With a mapping, the dict keys become the column suffixes.
        aggfunc = OrderedDict([('my_agg', Aggregator), ('my_agg2', Aggregator)])

        expr = self.expr.pivot_table(values='int16', rows='string', aggfunc=aggfunc)
        self.assertEqual(expr.schema.names, ['string', 'int16_my_agg', 'int16_my_agg2'])
        self.assertEqual(expr.schema.types, [types.string, types.float64, types.float64])

        # With ``columns`` the output columns depend on the data: dynamic schema.
        expr = self.expr.pivot_table(values='int16', columns='boolean', rows='string')
        self.assertIsInstance(expr, DynamicMixin)

    def testScaleValues(self):
        expr = self.expr.min_max_scale()
        self.assertIsInstance(expr, CollectionExpr)
        self.assertListEqual(expr.dtypes.names, self.expr.dtypes.names)

        # ``preserve=True`` keeps the originals and appends ``<name>_scaled`` for
        # the numeric (int*/float*) columns.
        expr = self.expr.min_max_scale(preserve=True)
        self.assertIsInstance(expr, CollectionExpr)
        self.assertListEqual(expr.dtypes.names, self.expr.dtypes.names +
                             [n + '_scaled' for n in self.expr.dtypes.names
                              if n.startswith('int') or n.startswith('float')])

        expr = self.expr.std_scale()
        self.assertIsInstance(expr, CollectionExpr)
        self.assertListEqual(expr.dtypes.names, self.expr.dtypes.names)

        expr = self.expr.std_scale(preserve=True)
        self.assertIsInstance(expr, CollectionExpr)
        self.assertListEqual(expr.dtypes.names, self.expr.dtypes.names +
                             [n + '_scaled' for n in self.expr.dtypes.names
                              if n.startswith('int') or n.startswith('float')])

    def testAppendId(self):
        expr = self.expr.append_id(id_col='id_col')
        if self._pandas_available():
            # Otherwise: go for Pandas
            self.assertIsInstance(expr, AppendIDCollectionExpr)
        else:
            # No pandas: go for XFlow
            self.assertIsInstance(expr, DataFrame)
        self.assertIn('id_col', expr.schema)

        self.assertIsInstance(self.sourced_expr.append_id(), DataFrame)

    def testSplit(self):
        expr1, expr2 = self.expr.split(0.6)
        if self._pandas_available():
            # Otherwise: go for Pandas
            self.assertIsInstance(expr1, SplitCollectionExpr)
            self.assertIsInstance(expr2, SplitCollectionExpr)
            self.assertTupleEqual((expr1._split_id, expr2._split_id), (0, 1))
        else:
            # No pandas: go for XFlow
            self.assertIsInstance(expr1, DataFrame)
            self.assertIsInstance(expr2, DataFrame)
class Test(TestBase):
    """Checks the SQL that ODPSEngine compiles for filter pushdown over mocked tables."""

    def setup(self):
        """Build expr/expr1/expr2 over one partitioned schema and expr3 over another."""
        def datatypes(*type_names):
            return [validate_data_type(t) for t in type_names]

        schema = Schema.from_lists(['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
                                   datatypes('string', 'int64', 'float64', 'boolean', 'decimal', 'datetime'),
                                   ['ds'], datatypes('string'))
        table = MockTable(name='pyodps_test_expr_table', schema=schema)
        self.expr = CollectionExpr(_source_data=table, _schema=schema)

        table1 = MockTable(name='pyodps_test_expr_table1', schema=schema)
        self.expr1 = CollectionExpr(_source_data=table1, _schema=schema)

        table2 = MockTable(name='pyodps_test_expr_table2', schema=schema)
        self.expr2 = CollectionExpr(_source_data=table2, _schema=schema)

        schema2 = Schema.from_lists(['name', 'id', 'fid'], datatypes('string', 'int64', 'float64'),
                                    ['part1', 'part2'], datatypes('string', 'int64'))
        # NOTE(review): table3 reuses table2's name; the join expectations
        # below reference `pyodps_test_expr_table2` for expr3, so keep in sync.
        table3 = MockTable(name='pyodps_test_expr_table2', schema=schema2)
        self.expr3 = CollectionExpr(_source_data=table3, _schema=schema2)

        # Show full diffs for the long SQL strings compared below.
        self.maxDiff = None

    def _assertSQLEqual(self, expected, expr):
        """Compile ``expr`` (non-prettified) and assert it equals ``expected``.

        Both sides are normalized with ``to_str`` so every comparison in this
        class is done the same way.
        """
        compiled = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(compiled))

    @staticmethod
    def _expected_cut_sql(operand, bins, labels):
        """Return the CASE expression expected for ``cut(..., include_over=True)``.

        One ``WHEN (lo < x) AND (x <= hi)`` branch is emitted per consecutive
        pair of ``bins`` (labelled in order), followed by an open-ended
        ``WHEN last < x`` branch that takes the final label.

        :param operand: SQL text of the value being cut, used verbatim.
        :param bins: ascending bin edges.
        :param labels: one label per bin edge (the last one is the overflow label).
        """
        clauses = ["WHEN (%s < %s) AND (%s <= %s) THEN '%s'" % (lo, operand, operand, hi, label)
                   for lo, hi, label in zip(bins, bins[1:], labels)]
        clauses.append("WHEN %s < %s THEN '%s'" % (bins[-1], operand, labels[-1]))
        return 'CASE ' + ' '.join(clauses) + ' END'

    def testFilterPushdownThroughProjection(self):
        # Filters on projected columns are rewritten in terms of the source
        # columns and merged into a single SELECT.
        expr = self.expr[self.expr.id + 1, 'name'][lambda x: x.id < 10]

        expected = 'SELECT t1.`id` + 1 AS `id`, t1.`name` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'WHERE (t1.`id` + 1) < 10'
        self._assertSQLEqual(expected, expr)

        expr = self.expr[self.expr.id + 1, 'name', self.expr.name.isnull().rename('is_null')][lambda x: x.is_null]

        expected = 'SELECT t1.`id` + 1 AS `id`, t1.`name`, t1.`name` IS NULL AS `is_null` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'WHERE t1.`name` IS NULL'
        self._assertSQLEqual(expected, expr)

        expr = self.expr['name', self.expr.id ** 2]\
            .filter(lambda x: x.name == 'name1').filter(lambda x: x.id < 3)
        expected = "SELECT t1.`name`, CAST(POW(t1.`id`, 2) AS BIGINT) AS `id` \n" \
                   "FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "WHERE (t1.`name` == 'name1') AND ((CAST(POW(t1.`id`, 2) AS BIGINT)) < 3)"
        self._assertSQLEqual(expected, expr)

        expr = self.expr['name', self.expr.id + 1].filter(lambda x: x.name == 'name1')[
            lambda x: 'tt' + x.name, 'id'
        ].filter(lambda x: x.id < 3)

        expected = "SELECT CONCAT('tt', t1.`name`) AS `name`, t1.`id` + 1 AS `id` \n" \
                   "FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "WHERE (t1.`name` == 'name1') AND ((t1.`id` + 1) < 3)"
        self._assertSQLEqual(expected, expr)

        expr = self.expr.filter(self.expr.name == 'name1').select('name', lambda x: (x.id + 1) * 2)[
            lambda x: 'tt' + x.name, 'id'
        ].filter(lambda x: x.id < 3)
        expected = "SELECT CONCAT('tt', t1.`name`) AS `name`, (t1.`id` + 1) * 2 AS `id` \n" \
                   "FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "WHERE (t1.`name` == 'name1') AND (((t1.`id` + 1) * 2) < 3)"
        self._assertSQLEqual(expected, expr)

        expr = self.expr.filter(self.expr.id.between(2, 6),
                                self.expr.name.lower().contains('pyodps', regex=False)).name.nunique()
        expected = "SELECT COUNT(DISTINCT t2.`name`) AS `name_nunique` \n" \
                   "FROM (\n" \
                   "  SELECT t1.`id`, t1.`name` \n" \
                   "  FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "  WHERE ((t1.`id` >= 2) AND (t1.`id` <= 6)) AND INSTR(TOLOWER(t1.`name`), 'pyodps') > 0 \n" \
                   ") t2"
        self._assertSQLEqual(expected, expr)

    def testFilterPushDownThroughJoin(self):
        # Predicates touching only one side of an inner join are pushed into
        # that side's subquery; predicates spanning both sides, or applied to
        # an outer join, stay above the join.
        expr = self.expr.join(self.expr3, on='name')
        expr = expr[(expr.id_x < 10) & (expr.fid_y > 3)]

        expected = 'SELECT t2.`name`, t2.`id` AS `id_x`, t2.`fid` AS `fid_x`, ' \
                   't2.`isMale`, t2.`scale`, t2.`birth`, t2.`ds`, t4.`id` AS `id_y`, ' \
                   't4.`fid` AS `fid_y`, t4.`part1`, t4.`part2` \n' \
                   'FROM (\n' \
                   '  SELECT * \n' \
                   '  FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   '  WHERE t1.`id` < 10\n' \
                   ') t2 \n' \
                   'INNER JOIN \n' \
                   '  (\n' \
                   '    SELECT * \n' \
                   '    FROM mocked_project.`pyodps_test_expr_table2` t3 \n' \
                   '    WHERE t3.`fid` > 3\n' \
                   '  ) t4\n' \
                   'ON t2.`name` == t4.`name`'
        self._assertSQLEqual(expected, expr)

        expr = self.expr.join(self.expr3, on='name')
        expr = expr[(expr.id_x < 10) & (expr.fid_y > 3) & (expr.id_x > 3)]

        expected = 'SELECT t2.`name`, t2.`id` AS `id_x`, t2.`fid` AS `fid_x`, ' \
                   't2.`isMale`, t2.`scale`, t2.`birth`, t2.`ds`, t4.`id` AS `id_y`, ' \
                   't4.`fid` AS `fid_y`, t4.`part1`, t4.`part2` \n' \
                   'FROM (\n' \
                   '  SELECT * \n' \
                   '  FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   '  WHERE (t1.`id` < 10) AND (t1.`id` > 3)\n' \
                   ') t2 \n' \
                   'INNER JOIN \n' \
                   '  (\n' \
                   '    SELECT * \n' \
                   '    FROM mocked_project.`pyodps_test_expr_table2` t3 \n' \
                   '    WHERE t3.`fid` > 3\n' \
                   '  ) t4\n' \
                   'ON t2.`name` == t4.`name`'
        self._assertSQLEqual(expected, expr)

        expr = self.expr[self.expr.name, self.expr.id + 1]
        expr2 = self.expr3['tt' + self.expr3.name, self.expr3.id.rename('id2')]
        expr = expr.join(expr2, on='name')
        expr = expr[((expr.id < 10) | (expr.id > 100)) & (expr.id2 > 3)]

        expected = "SELECT t2.`name`, t2.`id`, t4.`id2` \n" \
                   "FROM (\n" \
                   "  SELECT t1.`name`, t1.`id` + 1 AS `id` \n" \
                   "  FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "  WHERE ((t1.`id` + 1) < 10) OR ((t1.`id` + 1) > 100)\n" \
                   ") t2 \n" \
                   "INNER JOIN \n" \
                   "  (\n" \
                   "    SELECT CONCAT('tt', t3.`name`) AS `name`, t3.`id` AS `id2` \n" \
                   "    FROM mocked_project.`pyodps_test_expr_table2` t3 \n" \
                   "    WHERE t3.`id` > 3\n" \
                   "  ) t4\n" \
                   "ON t2.`name` == t4.`name`"
        self._assertSQLEqual(expected, expr)

        expr = self.expr.join(self.expr3, on='name')
        expr = expr[(expr.id_x + expr.id_y < 10) & (expr.id_x > 3)]

        expected = "SELECT * \n" \
                   "FROM (\n" \
                   "  SELECT t2.`name`, t2.`id` AS `id_x`, t2.`fid` AS `fid_x`, t2.`isMale`, " \
                   "t2.`scale`, t2.`birth`, t2.`ds`, t3.`id` AS `id_y`, " \
                   "t3.`fid` AS `fid_y`, t3.`part1`, t3.`part2` \n" \
                   "  FROM (\n" \
                   "    SELECT * \n" \
                   "    FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "    WHERE t1.`id` > 3\n" \
                   "  ) t2 \n" \
                   "  INNER JOIN \n" \
                   "    mocked_project.`pyodps_test_expr_table2` t3\n" \
                   "  ON t2.`name` == t3.`name` \n" \
                   ") t4 \n" \
                   "WHERE (t4.`id_x` + t4.`id_y`) < 10"
        self._assertSQLEqual(expected, expr)

        expr = self.expr.outer_join(self.expr3, on='name')
        expr = expr[(expr.id_x + expr.id_y < 10) & (expr.id_x > 3)]

        expected = "SELECT * \n" \
                   "FROM (\n" \
                   "  SELECT t1.`name` AS `name_x`, t1.`id` AS `id_x`, t1.`fid` AS `fid_x`, " \
                   "t1.`isMale`, t1.`scale`, t1.`birth`, t1.`ds`, t2.`name` AS `name_y`, " \
                   "t2.`id` AS `id_y`, t2.`fid` AS `fid_y`, t2.`part1`, t2.`part2` \n" \
                   "  FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "  FULL OUTER JOIN \n" \
                   "    mocked_project.`pyodps_test_expr_table2` t2\n" \
                   "  ON t1.`name` == t2.`name` \n" \
                   ") t3 \n" \
                   "WHERE ((t3.`id_x` + t3.`id_y`) < 10) AND (t3.`id_x` > 3)"
        self._assertSQLEqual(expected, expr)

        expr = self.expr.join(self.expr3, on=['name', self.expr.id == self.expr3.id,
                                              self.expr.id < 10, self.expr3.name == 'name1',
                                              self.expr.id > 5])

        expected = 'SELECT t2.`name`, t2.`id`, t2.`fid` AS `fid_x`, t2.`isMale`, ' \
                   't2.`scale`, t2.`birth`, t2.`ds`, t4.`fid` AS `fid_y`, t4.`part1`, t4.`part2` \n' \
                   'FROM (\n' \
                   '  SELECT * \n' \
                   '  FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   '  WHERE (t1.`id` < 10) AND (t1.`id` > 5)\n' \
                   ') t2 \n' \
                   'INNER JOIN \n' \
                   '  (\n' \
                   '    SELECT * \n' \
                   '    FROM mocked_project.`pyodps_test_expr_table2` t3 \n' \
                   '    WHERE t3.`name` == \'name1\'\n' \
                   '  ) t4\n' \
                   'ON (t2.`name` == t4.`name`) AND (t2.`id` == t4.`id`)'
        self._assertSQLEqual(expected, expr)

        # For a left join the single-side predicates stay in the ON clause.
        expr = self.expr.left_join(self.expr3, on=['name', self.expr.id == self.expr3.id,
                                                   self.expr.id < 10, self.expr3.name == 'name1',
                                                   self.expr.id > 5])
        expected = 'SELECT t1.`name` AS `name_x`, t1.`id` AS `id_x`, t1.`fid` AS `fid_x`, t1.`isMale`, ' \
                   't1.`scale`, t1.`birth`, t1.`ds`, t2.`name` AS `name_y`, t2.`id` AS `id_y`, ' \
                   't2.`fid` AS `fid_y`, t2.`part1`, t2.`part2` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'LEFT OUTER JOIN \n' \
                   '  mocked_project.`pyodps_test_expr_table2` t2\n' \
                   'ON ((((t1.`name` == t2.`name`) AND (t1.`id` == t2.`id`)) ' \
                   "AND (t1.`id` < 10)) AND (t2.`name` == 'name1')) AND (t1.`id` > 5)"
        self._assertSQLEqual(expected, expr)

    def testFilterPushdownThroughUnion(self):
        # A filter above a UNION ALL is copied into each branch.
        expr = self.expr['name', 'id'].union(self.expr2['id', 'name'])
        expr = expr.filter(expr.id + 1 < 3)

        expected = 'SELECT * \n' \
                   'FROM (\n' \
                   '  SELECT t1.`name`, t1.`id` \n' \
                   '  FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   '  WHERE (t1.`id` + 1) < 3 \n' \
                   '  UNION ALL\n' \
                   '    SELECT t2.`name`, t2.`id` \n' \
                   '    FROM mocked_project.`pyodps_test_expr_table2` t2 \n' \
                   '    WHERE (t2.`id` + 1) < 3\n' \
                   ') t3'
        self._assertSQLEqual(expected, expr)

        expr1 = self.expr.filter(self.expr.id == 1)['name', 'id']
        expr2 = self.expr.filter(self.expr.id == 0)['id', 'name']
        expr = expr1.union(expr2)

        expected = 'SELECT * \n' \
                   'FROM (\n' \
                   '  SELECT t1.`name`, t1.`id` \n' \
                   '  FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   '  WHERE t1.`id` == 1 \n' \
                   '  UNION ALL\n' \
                   '    SELECT t2.`name`, t2.`id` \n' \
                   '    FROM mocked_project.`pyodps_test_expr_table` t2 \n' \
                   '    WHERE t2.`id` == 0\n' \
                   ') t3'
        self._assertSQLEqual(expected, expr)

    def testGroupbyProjection(self):
        # Projecting a subset of aggregated columns drops the unused aggregates.
        expr = self.expr['id', 'name', 'fid']
        expr2 = expr.groupby('name').agg(count=expr.count(), id=expr.id.sum())
        expr3 = expr2['count', 'id']

        expected = "SELECT COUNT(1) AS `count`, SUM(t1.`id`) AS `id` \n" \
                   "FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "GROUP BY t1.`name`"
        self._assertSQLEqual(expected, expr3)

        expr = self.expr['id', 'name', 'fid'].filter(self.expr.id < 10)['name', 'id']
        expr2 = expr.groupby('name').agg(count=expr.count(), id=expr.id.sum(), name2=expr.name.max())
        expr3 = expr2[expr2.count + 1, 'id']

        expected = "SELECT COUNT(1) + 1 AS `count`, SUM(t1.`id`) AS `id` \n" \
                   "FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "WHERE t1.`id` < 10 \n" \
                   "GROUP BY t1.`name`"
        # Goes through to_str like every other comparison (it was previously
        # the only raw, un-normalized comparison in this class).
        self._assertSQLEqual(expected, expr3)

    def testFilterPushdownThroughMultipleProjection(self):
        schema = Schema.from_lists(list('abcde'), ['string'] * 5)
        table = MockTable(name='pyodps_test_expr_table3', schema=schema)
        tab = CollectionExpr(_source_data=table, _schema=odps_schema_to_df_schema(schema))

        # 30 bin edges 0, 7, ..., 203 labelled '0-7', '7-14', ..., '203-210'.
        bins2 = [7 * i for i in range(30)]
        labels2 = ['%d-%d' % (7 * i, 7 * (i + 1)) for i in range(30)]

        p1 = tab.select(tab.a,
                        tab.c.astype('int').cut(bins2, labels=labels2, include_over=True).rename('c_cut'),
                        tab.e.astype('int').rename('e'),
                        tab.c.astype('int').rename('c'))
        p1['f'] = p1['e'] / p1['c']

        # 20 bin edges 0..19 labelled '0'..'19'.
        f_bins = list(range(20))
        f_labels = [str(edge) for edge in f_bins]
        p2 = p1.select(p1.a, p1.c_cut, p1.f.cut(bins=f_bins, labels=f_labels, include_over=True).rename('f_cut'))

        # The filter on the derived `f_cut` column is pushed into a WHERE clause
        # over the base table, which repeats the whole CASE expression.
        c_case = self._expected_cut_sql('CAST(t1.`c` AS BIGINT)', bins2, labels2)
        f_case = self._expected_cut_sql('(CAST(t1.`e` AS BIGINT) / CAST(t1.`c` AS BIGINT))',
                                        f_bins, f_labels)
        expected = ''.join([
            "SELECT t1.`a`, ", c_case, " AS `c_cut`, ", f_case, " AS `f_cut` \n",
            "FROM mocked_project.`pyodps_test_expr_table3` t1 \n",
            "WHERE (", f_case, ") == '9'",
        ])

        self._assertSQLEqual(expected, p2[p2.f_cut == '9'])
# Ejemplo n.º 36 (Example No. 36)
# 0
class Test(TestBase):
    """Integration tests that run DataFrame expressions through ``ODPSEngine``.

    Each test writes rows into ``pyodps_test_engine_table`` (created in
    :meth:`setup`, dropped in :meth:`teardown`) and compares the engine output
    with values computed locally in plain Python or pandas.
    """

    def setup(self):
        datatypes = lambda *names: [validate_data_type(t) for t in names]
        schema = Schema.from_lists(['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
                                   datatypes('string', 'int64', 'float64', 'boolean', 'decimal', 'datetime'))
        self.schema = df_schema_to_odps_schema(schema)
        table_name = 'pyodps_test_engine_table'
        self.odps.delete_table(table_name, if_exists=True)
        # Reuse ``table_name`` so the deleted and the created table cannot drift apart.
        self.table = self.odps.create_table(name=table_name, schema=self.schema)
        self.expr = CollectionExpr(_source_data=self.table, _schema=schema)

        self.engine = ODPSEngine(self.odps)

        class FakeBar(object):
            """No-op progress bar passed to ``ODPSEngine._handle_cases``."""

            def update(self, *args, **kwargs):
                pass
        self.faked_bar = FakeBar()

    def _gen_random_bigint(self, value_range=None):
        """Return a random int in ``value_range`` (inclusive), defaulting to the bigint bounds."""
        return random.randint(*(value_range or types.bigint._bounds))

    def _gen_random_string(self, max_length=15):
        """Return a random string of between 1 and ``max_length`` letters."""
        # NOTE(review): assumes ``letters`` holds at least 52 characters (ASCII letters).
        gen_letter = lambda: letters[random.randint(0, 51)]
        return to_str(''.join([gen_letter() for _ in range(random.randint(1, max_length))]))

    def _gen_random_double(self):
        """Return a uniform random float in ``[-2**32, 2**32]``."""
        return random.uniform(-2**32, 2**32)

    def _gen_random_datetime(self):
        """Return a random local datetime between the epoch and now."""
        return datetime.fromtimestamp(random.randint(0, int(time.time())))

    def _gen_random_boolean(self):
        """Return a random boolean."""
        return random.uniform(-1, 1) > 0

    def _gen_random_decimal(self):
        """Return a ``Decimal`` built from the string form of a random double."""
        return Decimal(str(self._gen_random_double()))

    def _gen_data(self, rows=None, data=None, nullable_field=None, value_range=None):
        """Write rows into ``self.table`` and return them.

        :param rows: number of random rows to generate when ``data`` is None
        :param data: explicit rows to write; when given, nothing is generated
        :param nullable_field: column name set to None on every even-indexed
            generated row (only applies to generated data)
        :param value_range: ``(low, high)`` bounds for generated bigint columns
        :return: the list of rows that was written
        """
        if data is None:
            data = []
            for _ in range(rows):
                record = []
                for t in self.schema.types:
                    # Dispatch to ``_gen_random_<type name>`` for every column type.
                    method = getattr(self, '_gen_random_%s' % t.name)
                    if t.name == 'bigint':
                        record.append(method(value_range=value_range))
                    else:
                        record.append(method())
                data.append(record)

            if nullable_field is not None:
                j = self.schema._name_indexes[nullable_field]
                for i in range(0, len(data), 2):
                    data[i][j] = None

        self.odps.write_table(self.table, 0, [self.table.new_record(values=d) for d in data])
        return data

    def _get_result(self, res):
        """Normalize an engine result into plain Python lists when pandas is available."""
        if isinstance(res, ResultFrame):
            res = res.values
        try:
            import pandas

            if isinstance(res, pandas.DataFrame):
                return [list(it) for it in res.values]
            else:
                return res
        except ImportError:
            return res

    def _assert_list_almost_equal(self, first, second):
        """Element-wise ``assertAlmostEqual``; the unittest method cannot subtract lists."""
        self.assertEqual(len(first), len(second))
        for it1, it2 in zip(first, second):
            self.assertAlmostEqual(it1, it2)

    def testTunnelCases(self):
        data = self._gen_data(10, value_range=(-1000, 1000))

        expr = self.expr.count()
        res = self.engine._handle_cases(expr, self.faked_bar)
        result = self._get_result(res)
        self.assertEqual(10, result)

        expr = self.expr.name.count()
        res = self.engine._handle_cases(expr, self.faked_bar)
        result = self._get_result(res)
        self.assertEqual(10, result)

        res = self.engine._handle_cases(self.expr, self.faked_bar)
        result = self._get_result(res)
        self.assertEqual(data, result)

        expr = self.expr['name', self.expr.id.rename('new_id')]
        res = self.engine._handle_cases(expr, self.faked_bar)
        result = self._get_result(res)
        self.assertEqual([it[:2] for it in data], result)

        table_name = 'pyodps_test_engine_partitioned'
        self.odps.delete_table(table_name, if_exists=True)

        try:
            # Persist inside ``try`` so a partially created table is still cleaned up.
            df = self.expr.persist(table_name, partitions=['name'])

            expr = df.count()
            res = self.engine._handle_cases(expr, self.faked_bar)
            self.assertIsNone(res)

            expr = df[df.name == data[0][0]]['fid', 'id'].count()
            res = self.engine._handle_cases(expr, self.faked_bar)
            self.assertGreater(res, 0)

            expr = df[df.name == data[0][0]]['fid', 'id']
            res = self.engine._handle_cases(expr, self.faked_bar)
            self.assertGreater(len(res), 0)
        finally:
            self.odps.delete_table(table_name, if_exists=True)

    def testBase(self):
        data = self._gen_data(10, value_range=(-1000, 1000))

        expr = self.expr[self.expr.id < 10]['name', lambda x: x.id]
        result = self._get_result(self.engine.execute(expr).values)
        self.assertEqual(len([it for it in data if it[1] < 10]), len(result))
        if len(result) > 0:
            self.assertEqual(2, len(result[0]))

        expr = self.expr[Scalar(3).rename('const'), self.expr.id, (self.expr.id + 1).rename('id2')]
        res = self.engine.execute(expr)
        result = self._get_result(res.values)
        self.assertEqual([c.name for c in res.columns], ['const', 'id', 'id2'])
        self.assertTrue(all(it[0] == 3 for it in result))
        self.assertEqual(len(data), len(result))
        self.assertEqual([it[1]+1 for it in data], [it[2] for it in result])

        expr = self.expr.sort('id')[:5]
        res = self.engine.execute(expr)
        result = self._get_result(res.values)
        self.assertEqual(sorted(data, key=lambda it: it[1])[:5], result)

    def testElement(self):
        data = self._gen_data(5, nullable_field='name')

        fields = [
            self.expr.name.isnull().rename('name1'),
            self.expr.name.notnull().rename('name2'),
            self.expr.name.fillna('test').rename('name3'),
            self.expr.id.isin([1, 2, 3]).rename('id1'),
            self.expr.id.isin(self.expr.fid.astype('int')).rename('id2'),
            self.expr.id.notin([1, 2, 3]).rename('id3'),
            self.expr.id.notin(self.expr.fid.astype('int')).rename('id4'),
            self.expr.id.between(self.expr.fid, 3).rename('id5'),
            self.expr.name.fillna('test').switch('test', 'test' + self.expr.name.fillna('test'),
                                                 'test2', 'test2' + self.expr.name.fillna('test'),
                                                 default=self.expr.name).rename('name4'),
            self.expr.id.cut([100, 200, 300],
                             labels=['xsmall', 'small', 'large', 'xlarge'],
                             include_under=True, include_over=True).rename('id6')
        ]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        self.assertEqual(len(data), len(result))

        self.assertEqual(len([it for it in data if it[0] is None]),
                         len([it[0] for it in result if it[0]]))

        self.assertEqual(len([it[0] for it in data if it[0] is not None]),
                         len([it[1] for it in result if it[1]]))

        self.assertEqual([(it[0] if it[0] is not None else 'test') for it in data],
                         [it[2] for it in result])

        self.assertEqual([(it[1] in (1, 2, 3)) for it in data],
                         [it[3] for it in result])

        fids = [int(it[2]) for it in data]
        self.assertEqual([(it[1] in fids) for it in data],
                         [it[4] for it in result])

        self.assertEqual([(it[1] not in (1, 2, 3)) for it in data],
                         [it[5] for it in result])

        self.assertEqual([(it[1] not in fids) for it in data],
                         [it[6] for it in result])

        self.assertEqual([(it[2] <= it[1] <= 3) for it in data],
                         [it[7] for it in result])

        self.assertEqual([to_str('testtest' if it[0] is None else it[0]) for it in data],
                         [to_str(it[8]) for it in result])

        def get_val(val):
            # Local mirror of the ``cut`` bins/labels used for column ``id6``.
            if val <= 100:
                return 'xsmall'
            elif 100 < val <= 200:
                return 'small'
            elif 200 < val <= 300:
                return 'large'
            else:
                return 'xlarge'
        self.assertEqual([to_str(get_val(it[1])) for it in data], [to_str(it[9]) for it in result])

    def testArithmetic(self):
        data = self._gen_data(5, value_range=(-1000, 1000))

        fields = [
            (self.expr.id + 1).rename('id1'),
            (self.expr.fid - 1).rename('fid1'),
            (self.expr.scale * 2).rename('scale1'),
            (self.expr.scale + self.expr.id).rename('scale2'),
            (self.expr.id / 2).rename('id2'),
            (self.expr.id ** -2).rename('id3'),
            abs(self.expr.id).rename('id4'),
            (~self.expr.id).rename('id5'),
            (-self.expr.fid).rename('fid2'),
            (~self.expr.isMale).rename('isMale1'),
            (-self.expr.isMale).rename('isMale2'),
        ]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        self.assertEqual(len(data), len(result))

        self.assertEqual([it[1] + 1 for it in data],
                         [it[0] for it in result])

        # Float columns: compare element-wise with a tolerance.
        self._assert_list_almost_equal([it[2] - 1 for it in data],
                                       [it[1] for it in result])

        self.assertEqual([it[4] * 2 for it in data],
                         [it[2] for it in result])

        self.assertEqual([it[4] + it[1] for it in data],
                         [it[3] for it in result])

        self._assert_list_almost_equal([float(it[1]) / 2 for it in data],
                                       [it[4] for it in result])

        self.assertEqual([int(it[1] ** -2) for it in data],
                         [it[5] for it in result])

        self.assertEqual([abs(it[1]) for it in data],
                         [it[6] for it in result])

        self.assertEqual([~it[1] for it in data],
                         [it[7] for it in result])

        self._assert_list_almost_equal([-it[2] for it in data],
                                       [it[8] for it in result])

        self.assertEqual([not it[3] for it in data],
                         [it[9] for it in result])

        # NOTE(review): column 10 (isMale2) is only checked for successful execution.
        # TODO: test the datetime add and substract

    def testMath(self):
        data = self._gen_data(5, value_range=(1, 90))

        import numpy as np

        methods_to_fields = [
            (np.sin, self.expr.id.sin()),
            (np.cos, self.expr.id.cos()),
            (np.tan, self.expr.id.tan()),
            (np.sinh, self.expr.id.sinh()),
            (np.cosh, self.expr.id.cosh()),
            (np.tanh, self.expr.id.tanh()),
            (np.log, self.expr.id.log()),
            (np.log2, self.expr.id.log2()),
            (np.log10, self.expr.id.log10()),
            (np.log1p, self.expr.id.log1p()),
            (np.exp, self.expr.id.exp()),
            (np.expm1, self.expr.id.expm1()),
            (np.arccosh, self.expr.id.arccosh()),
            (np.arcsinh, self.expr.id.arcsinh()),
            (np.arctanh, self.expr.id.arctanh()),
            (np.arctan, self.expr.id.arctan()),
            (np.sqrt, self.expr.id.sqrt()),
            (np.abs, self.expr.id.abs()),
            (np.ceil, self.expr.id.ceil()),
            (np.floor, self.expr.id.floor()),
            (np.trunc, self.expr.id.trunc()),
        ]

        fields = [it[1].rename('id'+str(i)) for i, it in enumerate(methods_to_fields)]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        for i, (method, _) in enumerate(methods_to_fields):
            first = [method(row[1]) for row in data]
            second = [row[i] for row in result]
            self.assertEqual(len(first), len(second))
            for it1, it2 in zip(first, second):
                # NaN never equals NaN; treat matching NaNs (e.g. arctanh out of domain) as equal.
                if np.isnan(it1) and np.isnan(it2):
                    continue
                self.assertAlmostEqual(it1, it2)

    def testString(self):
        data = self._gen_data(5)

        methods_to_fields = [
            (lambda s: s.capitalize(), self.expr.name.capitalize()),
            (lambda s: data[0][0] in s, self.expr.name.contains(data[0][0], regex=False)),
            (lambda s: s.count(data[0][0]), self.expr.name.count(data[0][0])),
            (lambda s: s.endswith(data[0][0]), self.expr.name.endswith(data[0][0])),
            (lambda s: s.startswith(data[0][0]), self.expr.name.startswith(data[0][0])),
            (lambda s: s.find(data[0][0]), self.expr.name.find(data[0][0])),
            (lambda s: s.rfind(data[0][0]), self.expr.name.rfind(data[0][0])),
            (lambda s: s.replace(data[0][0], 'test'), self.expr.name.replace(data[0][0], 'test')),
            (lambda s: s[0], self.expr.name.get(0)),
            (lambda s: len(s), self.expr.name.len()),
            (lambda s: s.ljust(10), self.expr.name.ljust(10)),
            (lambda s: s.ljust(20, '*'), self.expr.name.ljust(20, fillchar='*')),
            (lambda s: s.rjust(10), self.expr.name.rjust(10)),
            (lambda s: s.rjust(20, '*'), self.expr.name.rjust(20, fillchar='*')),
            (lambda s: s * 4, self.expr.name.repeat(4)),
            (lambda s: s[2: 10: 2], self.expr.name.slice(2, 10, 2)),
            (lambda s: s[-5: -1], self.expr.name.slice(-5, -1)),
            (lambda s: s.title(), self.expr.name.title()),
            (lambda s: s.rjust(20, '0'), self.expr.name.zfill(20)),
            (lambda s: s.isalnum(), self.expr.name.isalnum()),
            (lambda s: s.isalpha(), self.expr.name.isalpha()),
            (lambda s: s.isdigit(), self.expr.name.isdigit()),
            (lambda s: s.isspace(), self.expr.name.isspace()),
            (lambda s: s.isupper(), self.expr.name.isupper()),
            (lambda s: s.istitle(), self.expr.name.istitle()),
            (lambda s: to_str(s).isnumeric(), self.expr.name.isnumeric()),
            (lambda s: to_str(s).isdecimal(), self.expr.name.isdecimal()),
        ]

        fields = [it[1].rename('id'+str(i)) for i, it in enumerate(methods_to_fields)]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        for i, (method, _) in enumerate(methods_to_fields):
            first = [method(row[0]) for row in data]
            second = [row[i] for row in result]
            self.assertEqual(first, second)

    def testDatetime(self):
        data = self._gen_data(5)

        import pandas as pd

        methods_to_fields = [
            (lambda s: list(s.birth.dt.year.values), self.expr.birth.year),
            (lambda s: list(s.birth.dt.month.values), self.expr.birth.month),
            (lambda s: list(s.birth.dt.day.values), self.expr.birth.day),
            (lambda s: list(s.birth.dt.hour.values), self.expr.birth.hour),
            (lambda s: list(s.birth.dt.minute.values), self.expr.birth.minute),
            (lambda s: list(s.birth.dt.second.values), self.expr.birth.second),
            (lambda s: list(s.birth.dt.weekofyear.values), self.expr.birth.weekofyear),
            (lambda s: list(s.birth.dt.dayofweek.values), self.expr.birth.dayofweek),
            (lambda s: list(s.birth.dt.weekday.values), self.expr.birth.weekday),
        ]

        fields = [it[1].rename('birth'+str(i)) for i, it in enumerate(methods_to_fields)]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        df = pd.DataFrame(data, columns=self.schema.names)

        for i, (method, _) in enumerate(methods_to_fields):
            first = method(df)
            second = [row[i] for row in result]
            self.assertEqual(first, second)

    def testSortDistinct(self):
        data = [
            ['name1', 4, None, None, None, None],
            ['name2', 2, None, None, None, None],
            ['name1', 4, None, None, None, None],
            ['name1', 3, None, None, None, None],
        ]
        self._gen_data(data=data)

        expr = self.expr.sort(['name', -self.expr.id]).distinct(['name', lambda x: x.id + 1])[:50]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        self.assertEqual(len(result), 3)

        expected = [
            ['name1', 5],
            ['name1', 4],
            ['name2', 3]
        ]
        self.assertEqual(expected, result)

    def testGroupbyAggregation(self):
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        expr = self.expr.groupby(['name', 'id'])[lambda x: x.fid.min() * 2 < 8] \
            .agg(self.expr.fid.max() + 1, new_id=self.expr.id.sum())

        res = self.engine.execute(expr)
        result = self._get_result(res)

        expected = [
            ['name1', 3, 5.1, 6],
            ['name2', 2, 4.5, 2]
        ]

        result = sorted(result, key=lambda k: k[0])

        self.assertEqual(expected, result)

        field = self.expr.groupby('name').sort(['id', -self.expr.fid]).row_number()
        expr = self.expr['name', 'id', 'fid', field]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        expected = [
            ['name1', 3, 4.1, 1],
            ['name1', 3, 2.2, 2],
            ['name1', 4, 5.3, 3],
            ['name1', 4, 4.2, 4],
            ['name2', 2, 3.5, 1],
        ]

        result = sorted(result, key=lambda k: (k[0], k[1], -k[2]))

        self.assertEqual(expected, result)

        expr = self.expr.name.value_counts()[:25]

        # The same per-name counts are expected from value_counts, topk and groupby.count.
        expected = [
            ['name1', 4],
            ['name2', 1]
        ]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        self.assertEqual(expected, result)

        expr = self.expr.name.topk(25)

        res = self.engine.execute(expr)
        result = self._get_result(res)

        self.assertEqual(expected, result)

        expr = self.expr.groupby('name').count()

        res = self.engine.execute(expr)
        result = self._get_result(res)

        self.assertEqual(expected, result)

    def testFilterGroupby(self):
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        expr = self.expr.groupby(['name']).agg(id=self.expr.id.max())[lambda x: x.id > 3]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        self.assertEqual(len(result), 1)

        expected = [
            ['name1', 4]
        ]

        self.assertEqual(expected, result)

    def testWindowRewrite(self):
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        expr = self.expr[self.expr.id - self.expr.id.mean() < 10][
            [lambda x: x.id - x.id.max()]][[lambda x: x.id - x.id.min()]][lambda x: x.id - x.id.std() > 0]

        # FIXME compiling too slow
        res = self.engine.execute(expr)
        result = self._get_result(res)

        import pandas as pd
        df = pd.DataFrame(data, columns=self.schema.names)
        expected = df.id - df.id.max()
        expected = expected - expected.min()
        expected = list(expected[expected - expected.std() > 0])

        self.assertEqual(expected, [it[0] for it in result])

    def testReduction(self):
        data = self._gen_data(rows=5, value_range=(-100, 100))

        import pandas as pd
        df = pd.DataFrame(data, columns=self.schema.names)

        # Each local reference is computed from the frame passed in as ``s``.
        methods_to_fields = [
            (lambda s: s.id.mean(), self.expr.id.mean()),
            (lambda s: len(s), self.expr.count()),
            (lambda s: s.id.var(ddof=0), self.expr.id.var(ddof=0)),
            (lambda s: s.id.std(ddof=0), self.expr.id.std(ddof=0)),
            (lambda s: s.id.median(), self.expr.id.median()),
            (lambda s: s.id.sum(), self.expr.id.sum()),
            (lambda s: s.id.min(), self.expr.id.min()),
            (lambda s: s.id.max(), self.expr.id.max()),
            (lambda s: s.isMale.min(), self.expr.isMale.min()),
            (lambda s: s.name.max(), self.expr.name.max()),
            (lambda s: s.birth.max(), self.expr.birth.max()),
            (lambda s: s.name.sum(), self.expr.name.sum()),
            (lambda s: s.isMale.sum(), self.expr.isMale.sum()),
        ]

        fields = [it[1].rename('f'+str(i)) for i, it in enumerate(methods_to_fields)]

        expr = self.expr[fields]

        res = self.engine.execute(expr)
        result = self._get_result(res)

        for i, (method, _) in enumerate(methods_to_fields):
            first = method(df)
            # Reductions produce a single row; column ``i`` holds the i-th aggregate.
            second = result[0][i]
            self.assertAlmostEqual(first, second)

    def testJoin(self):
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])
        table_name = 'pyodps_test_engine_table2'
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)

        try:
            # Everything after table creation runs under ``try`` so table2 is always dropped.
            expr2 = CollectionExpr(_source_data=table2, _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)

            data2 = [
                ['name1', 4, -1],
                ['name2', 1, -2]
            ]

            self.odps.write_table(table2, 0, [table2.new_record(values=d) for d in data2])

            expr = self.expr.join(expr2)['name', 'id2']

            res = self.engine.execute(expr)
            result = self._get_result(res)

            self.assertEqual(len(result), 5)
            expected = [
                [to_str('name1'), 4],
                [to_str('name2'), 1]
            ]
            self.assertTrue(all(it in expected for it in result))

            expr = self.expr.join(expr2, on=['name', ('id', 'id2')])[self.expr.name, expr2.id2]
            res = self.engine.execute(expr)
            result = self._get_result(res)
            self.assertEqual(len(result), 2)
            expected = [to_str('name1'), 4]
            self.assertTrue(all(it == expected for it in result))

        finally:
            table2.drop()

    def testUnion(self):
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]

        schema2 = Schema.from_lists(['name', 'id2', 'id3'],
                                    [types.string, types.bigint, types.bigint])
        table_name = 'pyodps_test_engine_table2'
        self.odps.delete_table(table_name, if_exists=True)
        table2 = self.odps.create_table(name=table_name, schema=schema2)

        try:
            # Everything after table creation runs under ``try`` so table2 is always dropped.
            expr2 = CollectionExpr(_source_data=table2, _schema=odps_schema_to_df_schema(schema2))

            self._gen_data(data=data)

            data2 = [
                ['name3', 5, -1],
                ['name4', 6, -2]
            ]

            self.odps.write_table(table2, 0, [table2.new_record(values=d) for d in data2])

            expr = self.expr['name', 'id'].distinct().union(expr2[expr2.id2.rename('id'), 'name'])

            res = self.engine.execute(expr)
            result = self._get_result(res)

            expected = [
                ['name1', 4],
                ['name1', 3],
                ['name2', 2],
                ['name3', 5],
                ['name4', 6]
            ]

            result = sorted(result)
            expected = sorted(expected)

            self.assertEqual(len(result), len(expected))
            for r, e in zip(result, expected):
                self.assertEqual([to_str(t) for t in r],
                                 [to_str(t) for t in e])

        finally:
            table2.drop()

    def testPersist(self):
        data = [
            ['name1', 4, 5.3, None, None, None],
            ['name2', 2, 3.5, None, None, None],
            ['name1', 4, 4.2, None, None, None],
            ['name1', 3, 2.2, None, None, None],
            ['name1', 3, 4.1, None, None, None],
        ]
        self._gen_data(data=data)

        table_name = 'pyodps_test_engine_persist_table'

        try:
            df = self.expr.persist(table_name)

            res = self.engine.execute(df)
            result = self._get_result(res)
            self.assertEqual(len(result), 5)
            self.assertEqual(data, result)
        finally:
            self.odps.delete_table(table_name, if_exists=True)

        try:
            self.expr.persist(table_name, partitions=['name'])

            t = self.odps.get_table(table_name)
            self.assertEqual(2, len(list(t.partitions)))
            with t.open_reader(partition='name=name1', reopen=True) as r:
                self.assertEqual(4, r.count)
            with t.open_reader(partition='name=name2', reopen=True) as r:
                self.assertEqual(1, r.count)
        finally:
            self.odps.delete_table(table_name, if_exists=True)

    def teardown(self):
        self.table.drop()
class Test(TestBase):
    """SQL compilation tests for filter push-down through projections, joins and unions.

    Every case builds an expression over mocked tables and compares the output of
    ``ODPSEngine(self.odps).compile(expr, prettify=False)`` with an expected string.
    """

    def setup(self):
        def to_types(*names):
            return [validate_data_type(name) for name in names]

        main_schema = Schema.from_lists(
            ['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
            to_types('string', 'int64', 'float64', 'boolean', 'decimal', 'datetime'),
            ['ds'], to_types('string'))
        self.expr = CollectionExpr(
            _source_data=MockTable(name='pyodps_test_expr_table', schema=main_schema),
            _schema=main_schema)
        self.expr1 = CollectionExpr(
            _source_data=MockTable(name='pyodps_test_expr_table1', schema=main_schema),
            _schema=main_schema)
        self.expr2 = CollectionExpr(
            _source_data=MockTable(name='pyodps_test_expr_table2', schema=main_schema),
            _schema=main_schema)

        # NOTE: expr3's mock table reuses the name `pyodps_test_expr_table2`;
        # the expected SQL below refers to it by that name.
        part_schema = Schema.from_lists(['name', 'id', 'fid'],
                                        to_types('string', 'int64', 'float64'),
                                        ['part1', 'part2'], to_types('string', 'int64'))
        self.expr3 = CollectionExpr(
            _source_data=MockTable(name='pyodps_test_expr_table2', schema=part_schema),
            _schema=part_schema)

        self.maxDiff = None

    def _assert_compiles_to(self, expected, expr):
        """Compile *expr* without prettifying and assert the SQL equals *expected*."""
        compiled = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(compiled))

    def testFilterPushdownThroughProjection(self):
        projected = self.expr[self.expr.id + 1, 'name'][lambda x: x.id < 10]
        expected = ('SELECT t1.`id` + 1 AS `id`, t1.`name` \n'
                    'FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    'WHERE (t1.`id` + 1) < 10')
        self._assert_compiles_to(expected, projected)

        squared = self.expr['name', self.expr.id ** 2]
        squared = squared.filter(lambda x: x.name == 'name1').filter(lambda x: x.id < 3)
        expected = ("SELECT t1.`name`, CAST(POW(t1.`id`, 2) AS BIGINT) AS `id` \n"
                    "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "WHERE (t1.`name` == 'name1') AND ((CAST(POW(t1.`id`, 2) AS BIGINT)) < 3)")
        self._assert_compiles_to(expected, squared)

        prefixed = self.expr['name', self.expr.id + 1].filter(lambda x: x.name == 'name1')
        prefixed = prefixed[lambda x: 'tt' + x.name, 'id'].filter(lambda x: x.id < 3)
        expected = ("SELECT CONCAT('tt', t1.`name`) AS `name`, t1.`id` + 1 AS `id` \n"
                    "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "WHERE (t1.`name` == 'name1') AND ((t1.`id` + 1) < 3)")
        self._assert_compiles_to(expected, prefixed)

        doubled = self.expr.filter(self.expr.name == 'name1')
        doubled = doubled.select('name', lambda x: (x.id + 1) * 2)
        doubled = doubled[lambda x: 'tt' + x.name, 'id'].filter(lambda x: x.id < 3)
        expected = ("SELECT CONCAT('tt', t1.`name`) AS `name`, (t1.`id` + 1) * 2 AS `id` \n"
                    "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "WHERE (((t1.`id` + 1) * 2) < 3) AND (t1.`name` == 'name1')")
        self._assert_compiles_to(expected, doubled)

        in_range = self.expr.id.between(2, 6)
        has_keyword = self.expr.name.lower().contains('pyodps', regex=False)
        distinct_names = self.expr.filter(in_range, has_keyword).name.nunique()
        expected = ("SELECT COUNT(DISTINCT t2.`name`) AS `name_nunique` \n"
                    "FROM (\n"
                    "  SELECT t1.`id`, t1.`name` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "  WHERE ((t1.`id` >= 2) AND (t1.`id` <= 6)) AND INSTR(TOLOWER(t1.`name`), 'pyodps') > 0 \n"
                    ") t2")
        self._assert_compiles_to(expected, distinct_names)

    def testFilterPushDownThroughJoin(self):
        joined = self.expr.join(self.expr3, on='name')
        filtered = joined[(joined.id_x < 10) & (joined.fid_y > 3)]
        expected = ('SELECT t2.`name`, t2.`id` AS `id_x`, t2.`fid` AS `fid_x`, '
                    't2.`isMale`, t2.`scale`, t2.`birth`, t2.`ds`, t4.`id` AS `id_y`, '
                    't4.`fid` AS `fid_y`, t4.`part1`, t4.`part2` \n'
                    'FROM (\n'
                    '  SELECT * \n'
                    '  FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    '  WHERE t1.`id` < 10\n'
                    ') t2 \n'
                    'INNER JOIN \n'
                    '  (\n'
                    '    SELECT * \n'
                    '    FROM mocked_project.`pyodps_test_expr_table2` t3 \n'
                    '    WHERE t3.`fid` > 3\n'
                    '  ) t4\n'
                    'ON t2.`name` == t4.`name`')
        self._assert_compiles_to(expected, filtered)

        joined = self.expr.join(self.expr3, on='name')
        filtered = joined[(joined.id_x < 10) & (joined.fid_y > 3) & (joined.id_x > 3)]
        expected = ('SELECT t2.`name`, t2.`id` AS `id_x`, t2.`fid` AS `fid_x`, '
                    't2.`isMale`, t2.`scale`, t2.`birth`, t2.`ds`, t4.`id` AS `id_y`, '
                    't4.`fid` AS `fid_y`, t4.`part1`, t4.`part2` \n'
                    'FROM (\n'
                    '  SELECT * \n'
                    '  FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    '  WHERE (t1.`id` < 10) AND (t1.`id` > 3)\n'
                    ') t2 \n'
                    'INNER JOIN \n'
                    '  (\n'
                    '    SELECT * \n'
                    '    FROM mocked_project.`pyodps_test_expr_table2` t3 \n'
                    '    WHERE t3.`fid` > 3\n'
                    '  ) t4\n'
                    'ON t2.`name` == t4.`name`')
        self._assert_compiles_to(expected, filtered)

        left_side = self.expr[self.expr.name, self.expr.id + 1]
        right_side = self.expr3['tt' + self.expr3.name, self.expr3.id.rename('id2')]
        joined = left_side.join(right_side, on='name')
        filtered = joined[((joined.id < 10) | (joined.id > 100)) & (joined.id2 > 3)]
        expected = ("SELECT t2.`name`, t2.`id`, t4.`id2` \n"
                    "FROM (\n"
                    "  SELECT t1.`name`, t1.`id` + 1 AS `id` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "  WHERE ((t1.`id` + 1) < 10) OR ((t1.`id` + 1) > 100)\n"
                    ") t2 \n"
                    "INNER JOIN \n"
                    "  (\n"
                    "    SELECT CONCAT('tt', t3.`name`) AS `name`, t3.`id` AS `id2` \n"
                    "    FROM mocked_project.`pyodps_test_expr_table2` t3 \n"
                    "    WHERE t3.`id` > 3\n"
                    "  ) t4\n"
                    "ON t2.`name` == t4.`name`")
        self._assert_compiles_to(expected, filtered)

        # A predicate spanning both sides cannot be pushed down and stays outside.
        joined = self.expr.join(self.expr3, on='name')
        filtered = joined[(joined.id_x + joined.id_y < 10) & (joined.id_x > 3)]
        expected = ("SELECT * \n"
                    "FROM (\n"
                    "  SELECT t2.`name`, t2.`id` AS `id_x`, t2.`fid` AS `fid_x`, t2.`isMale`, "
                    "t2.`scale`, t2.`birth`, t2.`ds`, t3.`id` AS `id_y`, "
                    "t3.`fid` AS `fid_y`, t3.`part1`, t3.`part2` \n"
                    "  FROM (\n"
                    "    SELECT * \n"
                    "    FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "    WHERE t1.`id` > 3\n"
                    "  ) t2 \n"
                    "  INNER JOIN \n"
                    "    mocked_project.`pyodps_test_expr_table2` t3\n"
                    "  ON t2.`name` == t3.`name` \n"
                    ") t4 \n"
                    "WHERE (t4.`id_x` + t4.`id_y`) < 10")
        self._assert_compiles_to(expected, filtered)

        joined = self.expr.outer_join(self.expr3, on='name')
        filtered = joined[(joined.id_x + joined.id_y < 10) & (joined.id_x > 3)]
        expected = ("SELECT * \n"
                    "FROM (\n"
                    "  SELECT t1.`name` AS `name_x`, t1.`id` AS `id_x`, t1.`fid` AS `fid_x`, "
                    "t1.`isMale`, t1.`scale`, t1.`birth`, t1.`ds`, t2.`name` AS `name_y`, "
                    "t2.`id` AS `id_y`, t2.`fid` AS `fid_y`, t2.`part1`, t2.`part2` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "  FULL OUTER JOIN \n"
                    "    mocked_project.`pyodps_test_expr_table2` t2\n"
                    "  ON t1.`name` == t2.`name` \n"
                    ") t3 \n"
                    "WHERE ((t3.`id_x` + t3.`id_y`) < 10) AND (t3.`id_x` > 3)")
        self._assert_compiles_to(expected, filtered)

        conditions = ['name', self.expr.id == self.expr3.id,
                      self.expr.id < 10, self.expr3.name == 'name1',
                      self.expr.id > 5]
        joined = self.expr.join(self.expr3, on=conditions)
        expected = ('SELECT t2.`name`, t2.`id`, t2.`fid` AS `fid_x`, t2.`isMale`, '
                    't2.`scale`, t2.`birth`, t2.`ds`, t4.`fid` AS `fid_y`, t4.`part1`, t4.`part2` \n'
                    'FROM (\n'
                    '  SELECT * \n'
                    '  FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    '  WHERE (t1.`id` < 10) AND (t1.`id` > 5)\n'
                    ') t2 \n'
                    'INNER JOIN \n'
                    '  (\n'
                    '    SELECT * \n'
                    '    FROM mocked_project.`pyodps_test_expr_table2` t3 \n'
                    '    WHERE t3.`name` == \'name1\'\n'
                    '  ) t4\n'
                    'ON (t2.`name` == t4.`name`) AND (t2.`id` == t4.`id`)')
        self._assert_compiles_to(expected, joined)

        # For a LEFT OUTER JOIN all predicates stay in the ON clause.
        conditions = ['name', self.expr.id == self.expr3.id,
                      self.expr.id < 10, self.expr3.name == 'name1',
                      self.expr.id > 5]
        joined = self.expr.left_join(self.expr3, on=conditions)
        expected = ('SELECT t1.`name` AS `name_x`, t1.`id` AS `id_x`, t1.`fid` AS `fid_x`, t1.`isMale`, '
                    't1.`scale`, t1.`birth`, t1.`ds`, t2.`name` AS `name_y`, t2.`id` AS `id_y`, '
                    't2.`fid` AS `fid_y`, t2.`part1`, t2.`part2` \n'
                    'FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    'LEFT OUTER JOIN \n'
                    '  mocked_project.`pyodps_test_expr_table2` t2\n'
                    'ON ((((t1.`name` == t2.`name`) AND (t1.`id` == t2.`id`)) '
                    "AND (t1.`id` < 10)) AND (t2.`name` == 'name1')) AND (t1.`id` > 5)")
        self._assert_compiles_to(expected, joined)

    def testFilterPushdownThroughUnion(self):
        unioned = self.expr['name', 'id'].union(self.expr2['id', 'name'])
        filtered = unioned.filter(unioned.id + 1 < 3)
        expected = ('SELECT * \n'
                    'FROM (\n'
                    '  SELECT t1.`name`, t1.`id` \n'
                    '  FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    '  WHERE (t1.`id` + 1) < 3 \n'
                    '  UNION ALL\n'
                    '    SELECT t2.`name`, t2.`id` \n'
                    '    FROM mocked_project.`pyodps_test_expr_table2` t2 \n'
                    '    WHERE (t2.`id` + 1) < 3\n'
                    ') t3')
        self._assert_compiles_to(expected, filtered)

        first_half = self.expr.filter(self.expr.id == 1)['name', 'id']
        second_half = self.expr.filter(self.expr.id == 0)['id', 'name']
        unioned = first_half.union(second_half)
        expected = ('SELECT * \n'
                    'FROM (\n'
                    '  SELECT t1.`name`, t1.`id` \n'
                    '  FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    '  WHERE t1.`id` == 1 \n'
                    '  UNION ALL\n'
                    '    SELECT t2.`name`, t2.`id` \n'
                    '    FROM mocked_project.`pyodps_test_expr_table` t2 \n'
                    '    WHERE t2.`id` == 0\n'
                    ') t3')
        self._assert_compiles_to(expected, unioned)
class Test(TestBase):
    """Checks ``repr()`` output of expressions against the ``EXPECTED_*`` format fixtures."""

    def setup(self):
        """Create ``self.expr`` (name, id) and ``self.expr2`` (name2, id2) over mocked tables."""
        datatypes = lambda *types: [validate_data_type(t) for t in types]
        schema = Schema.from_lists(["name", "id"], datatypes("string", "int64"))
        table = MockTable(name="pyodps_test_expr_table", schema=schema)

        self.expr = CollectionExpr(_source_data=table, _schema=schema)

        schema2 = Schema.from_lists(["name2", "id2"], datatypes("string", "int64"))
        table2 = MockTable(name="pyodps_test_expr_table2", schema=schema2)
        self.expr2 = CollectionExpr(_source_data=table2, _schema=schema2)

    def _lines_eq(self, expected, actual):
        """Assert two multi-line strings match line by line, ignoring trailing whitespace."""
        self.assertSequenceEqual(
            [to_str(line.rstrip()) for line in expected.split("\n")],
            [to_str(line.rstrip()) for line in actual.split("\n")],
        )

    def testProjectionFormatter(self):
        expr = self.expr["name", self.expr.id.rename("new_id")].new_id.astype("float32")
        self._lines_eq(EXPECTED_PROJECTION_FORMAT, repr(expr))

    def testFilterFormatter(self):
        expr = self.expr[(self.expr.name != "test") & (self.expr.id > 100)]
        self._lines_eq(EXPECTED_FILTER_FORMAT, repr(expr))

    def testSliceFormatter(self):
        expr = self.expr[:100]
        self._lines_eq(EXPECTED_SLICE_FORMAT, repr(expr))

        expr = self.expr[5:100:3]
        self._lines_eq(EXPECTED_SLICE_WITH_START_STEP_FORMAT, repr(expr))

    def testArithmeticFormatter(self):
        expr = self.expr
        d = -(expr["id"]) + 20.34 - expr["id"] + float(20) * expr["id"] - expr["id"] / 4.9 + 40 // 2 + expr["id"] // 1.2

        try:
            self._lines_eq(EXPECTED_ARITHMETIC_FORMAT, repr(d))
        except AssertionError as e:
            # Fallback: a line that differs textually is still accepted when both
            # sides parse as floats that are almost equal (e.g. numeric literals
            # rendered with different precision). Any other mismatch re-raises
            # the original comparison failure.
            left = [to_str(line.rstrip()) for line in EXPECTED_ARITHMETIC_FORMAT.split("\n")]
            right = [to_str(line.rstrip()) for line in repr(d).split("\n")]
            self.assertEqual(len(left), len(right))
            for expected_line, actual_line in zip(left, right):
                if expected_line == actual_line:
                    continue
                try:
                    self.assertAlmostEqual(float(expected_line), float(actual_line))
                except (ValueError, AssertionError):
                    # ValueError: the line is not numeric; AssertionError: the
                    # numbers differ. Narrowed from a bare ``except:`` so that
                    # KeyboardInterrupt/SystemExit are no longer swallowed.
                    raise e

    def testSortFormatter(self):
        expr = self.expr.sort(["name", -self.expr.id])

        self._lines_eq(EXPECTED_SORT_FORMAT, repr(expr))

    def testDistinctFormatter(self):
        expr = self.expr.distinct(["name", self.expr.id + 1])

        self._lines_eq(EXPECTED_DISTINCT_FORMAT, repr(expr))

    def testGroupbyFormatter(self):
        expr = self.expr.groupby(["name", "id"]).agg(new_id=self.expr.id.sum())

        self._lines_eq(EXPECTED_GROUPBY_FORMAT, repr(expr))

        grouped = self.expr.groupby(["name"])
        expr = grouped.mutate(grouped.row_number())

        self._lines_eq(EXPECTED_MUTATE_FORMAT, repr(expr))

        expr = self.expr.groupby(["name", "id"]).count()
        self._lines_eq(EXPECTED_GROUPBY_COUNT_FORMAT, repr(expr))

    def testReductionFormatter(self):
        expr = self.expr.groupby(["id"]).id.std()

        self._lines_eq(EXPECTED_REDUCTION_FORMAT, repr(expr))

        expr = self.expr.id.mean()
        self._lines_eq(EXPECTED_REDUCTION_FORMAT2, repr(expr))

        expr = self.expr.count()
        self._lines_eq(EXPECTED_REDUCTION_FORMAT3, repr(expr))

    def testWindowFormatter(self):
        expr = self.expr.groupby(["name"]).sort(-self.expr.id).name.rank()

        self._lines_eq(EXPECTED_WINDOW_FORMAT1, repr(expr))

        expr = self.expr.groupby(["id"]).id.cummean(preceding=10, following=5, unique=True)

        self._lines_eq(EXPECTED_WINDOW_FORMAT2, repr(expr))

    def testElementFormatter(self):
        expr = self.expr.name.contains("test")

        self._lines_eq(EXPECTED_STRING_FORMAT, repr(expr))

        expr = self.expr.id.between(1, 3)

        self._lines_eq(EXPECTED_ELEMENT_FORMAT, repr(expr))

        expr = self.expr.name.astype("datetime").strftime("%Y")

        self._lines_eq(EXPECTED_DATETIME_FORMAT, repr(expr))

        expr = self.expr.id.switch(3, self.expr.name, 4, self.expr.name + "abc", default=self.expr.name + "test")
        self._lines_eq(EXPECTED_SWITCH_FORMAT, repr(expr))

    def testJoinFormatter(self):
        expr = self.expr.join(self.expr2, ("name", "name2"))
        self._lines_eq(EXPECTED_JOIN_FORMAT, repr(expr))

    def testAstypeFormatter(self):
        expr = self.expr.id.astype("float")
        self._lines_eq(EXPECTED_CAST_FORMAT, repr(expr))
Ejemplo n.º 39
0
class Test(TestBase):
    """Filter push-down compilation checks against mocked ODPS tables.

    Each case compiles an expression with ``prettify=False`` and compares the
    resulting SQL text with a hand-written expectation.
    """

    def setup(self):
        def as_types(*type_names):
            return [validate_data_type(type_name) for type_name in type_names]

        full_schema = Schema.from_lists(
            ['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
            as_types('string', 'int64', 'float64', 'boolean', 'decimal', 'datetime'),
            ['ds'], as_types('string'))

        base_table = MockTable(name='pyodps_test_expr_table', schema=full_schema)
        self.expr = CollectionExpr(_source_data=base_table, _schema=full_schema)

        first_table = MockTable(name='pyodps_test_expr_table1', schema=full_schema)
        self.expr1 = CollectionExpr(_source_data=first_table, _schema=full_schema)

        second_table = MockTable(name='pyodps_test_expr_table2', schema=full_schema)
        self.expr2 = CollectionExpr(_source_data=second_table, _schema=full_schema)

        # NOTE: this partitioned mock table shares the name `pyodps_test_expr_table2`,
        # which is how the expected SQL below refers to expr3.
        partitioned_schema = Schema.from_lists(['name', 'id', 'fid'],
                                               as_types('string', 'int64', 'float64'),
                                               ['part1', 'part2'],
                                               as_types('string', 'int64'))
        partitioned_table = MockTable(name='pyodps_test_expr_table2', schema=partitioned_schema)
        self.expr3 = CollectionExpr(_source_data=partitioned_table, _schema=partitioned_schema)

        self.maxDiff = None

    def _check_sql(self, expr, expected):
        """Assert that *expr* compiles (unprettified) to exactly *expected*."""
        sql = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(sql))

    def testFilterPushdownThroughProjection(self):
        node = self.expr[self.expr.id + 1, 'name'][lambda x: x.id < 10]
        self._check_sql(node, ('SELECT t1.`id` + 1 AS `id`, t1.`name` \n'
                               'FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                               'WHERE (t1.`id` + 1) < 10'))

        node = (self.expr['name', self.expr.id ** 2]
                .filter(lambda x: x.name == 'name1')
                .filter(lambda x: x.id < 3))
        self._check_sql(node, ("SELECT t1.`name`, CAST(POW(t1.`id`, 2) AS BIGINT) AS `id` \n"
                               "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                               "WHERE (t1.`name` == 'name1') AND ((CAST(POW(t1.`id`, 2) AS BIGINT)) < 3)"))

        node = (self.expr['name', self.expr.id + 1]
                .filter(lambda x: x.name == 'name1')[lambda x: 'tt' + x.name, 'id']
                .filter(lambda x: x.id < 3))
        self._check_sql(node, ("SELECT CONCAT('tt', t1.`name`) AS `name`, t1.`id` + 1 AS `id` \n"
                               "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                               "WHERE (t1.`name` == 'name1') AND ((t1.`id` + 1) < 3)"))

        node = (self.expr.filter(self.expr.name == 'name1')
                .select('name', lambda x: (x.id + 1) * 2)[lambda x: 'tt' + x.name, 'id']
                .filter(lambda x: x.id < 3))
        self._check_sql(node, ("SELECT CONCAT('tt', t1.`name`) AS `name`, (t1.`id` + 1) * 2 AS `id` \n"
                               "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                               "WHERE (((t1.`id` + 1) * 2) < 3) AND (t1.`name` == 'name1')"))

    def testFilterPushDownThroughJoin(self):
        pair = self.expr.join(self.expr3, on='name')
        node = pair[(pair.id_x < 10) & (pair.fid_y > 3)]
        self._check_sql(node, ('SELECT t2.`name`, t2.`id` AS `id_x`, t2.`fid` AS `fid_x`, '
                               't2.`isMale`, t2.`scale`, t2.`birth`, t2.`ds`, t4.`id` AS `id_y`, '
                               't4.`fid` AS `fid_y`, t4.`part1`, t4.`part2` \n'
                               'FROM (\n'
                               '  SELECT * \n'
                               '  FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                               '  WHERE t1.`id` < 10\n'
                               ') t2 \n'
                               'INNER JOIN \n'
                               '  (\n'
                               '    SELECT * \n'
                               '    FROM mocked_project.`pyodps_test_expr_table2` t3 \n'
                               '    WHERE t3.`fid` > 3\n'
                               '  ) t4\n'
                               'ON t2.`name` == t4.`name`'))

        pair = self.expr.join(self.expr3, on='name')
        node = pair[(pair.id_x < 10) & (pair.fid_y > 3) & (pair.id_x > 3)]
        self._check_sql(node, ('SELECT t2.`name`, t2.`id` AS `id_x`, t2.`fid` AS `fid_x`, '
                               't2.`isMale`, t2.`scale`, t2.`birth`, t2.`ds`, t4.`id` AS `id_y`, '
                               't4.`fid` AS `fid_y`, t4.`part1`, t4.`part2` \n'
                               'FROM (\n'
                               '  SELECT * \n'
                               '  FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                               '  WHERE (t1.`id` < 10) AND (t1.`id` > 3)\n'
                               ') t2 \n'
                               'INNER JOIN \n'
                               '  (\n'
                               '    SELECT * \n'
                               '    FROM mocked_project.`pyodps_test_expr_table2` t3 \n'
                               '    WHERE t3.`fid` > 3\n'
                               '  ) t4\n'
                               'ON t2.`name` == t4.`name`'))

        lhs = self.expr[self.expr.name, self.expr.id + 1]
        rhs = self.expr3['tt' + self.expr3.name, self.expr3.id.rename('id2')]
        pair = lhs.join(rhs, on='name')
        node = pair[((pair.id < 10) | (pair.id > 100)) & (pair.id2 > 3)]
        self._check_sql(node, ("SELECT t2.`name`, t2.`id`, t4.`id2` \n"
                               "FROM (\n"
                               "  SELECT t1.`name`, t1.`id` + 1 AS `id` \n"
                               "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                               "  WHERE ((t1.`id` + 1) < 10) OR ((t1.`id` + 1) > 100)\n"
                               ") t2 \n"
                               "INNER JOIN \n"
                               "  (\n"
                               "    SELECT CONCAT('tt', t3.`name`) AS `name`, t3.`id` AS `id2` \n"
                               "    FROM mocked_project.`pyodps_test_expr_table2` t3 \n"
                               "    WHERE t3.`id` > 3\n"
                               "  ) t4\n"
                               "ON t2.`name` == t4.`name`"))

        # The cross-side predicate (id_x + id_y) stays above the join.
        pair = self.expr.join(self.expr3, on='name')
        node = pair[(pair.id_x + pair.id_y < 10) & (pair.id_x > 3)]
        self._check_sql(node, ("SELECT * \n"
                               "FROM (\n"
                               "  SELECT t2.`name`, t2.`id` AS `id_x`, t2.`fid` AS `fid_x`, t2.`isMale`, "
                               "t2.`scale`, t2.`birth`, t2.`ds`, t3.`id` AS `id_y`, "
                               "t3.`fid` AS `fid_y`, t3.`part1`, t3.`part2` \n"
                               "  FROM (\n"
                               "    SELECT * \n"
                               "    FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                               "    WHERE t1.`id` > 3\n"
                               "  ) t2 \n"
                               "  INNER JOIN \n"
                               "    mocked_project.`pyodps_test_expr_table2` t3\n"
                               "  ON t2.`name` == t3.`name` \n"
                               ") t4 \n"
                               "WHERE (t4.`id_x` + t4.`id_y`) < 10"))

        join_on = [
            'name', self.expr.id == self.expr3.id,
            self.expr.id < 10,
            self.expr3.name == 'name1', self.expr.id > 5
        ]
        node = self.expr.join(self.expr3, on=join_on)
        self._check_sql(node, ('SELECT t2.`name`, t2.`id`, t2.`fid` AS `fid_x`, t2.`isMale`, '
                               't2.`scale`, t2.`birth`, t2.`ds`, t4.`fid` AS `fid_y`, t4.`part1`, t4.`part2` \n'
                               'FROM (\n'
                               '  SELECT * \n'
                               '  FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                               '  WHERE (t1.`id` < 10) AND (t1.`id` > 5)\n'
                               ') t2 \n'
                               'INNER JOIN \n'
                               '  (\n'
                               '    SELECT * \n'
                               '    FROM mocked_project.`pyodps_test_expr_table2` t3 \n'
                               '    WHERE t3.`name` == \'name1\'\n'
                               '  ) t4\n'
                               'ON (t2.`name` == t4.`name`) AND (t2.`id` == t4.`id`)'))

    def testFilterPushdownThroughUnion(self):
        combined = self.expr['name', 'id'].union(self.expr2['id', 'name'])
        node = combined.filter(combined.id + 1 < 3)
        self._check_sql(node, ('SELECT * \n'
                               'FROM (\n'
                               '  SELECT t1.`name`, t1.`id` \n'
                               '  FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                               '  WHERE (t1.`id` + 1) < 3 \n'
                               '  UNION ALL\n'
                               '    SELECT t2.`name`, t2.`id` \n'
                               '    FROM mocked_project.`pyodps_test_expr_table2` t2 \n'
                               '    WHERE (t2.`id` + 1) < 3\n'
                               ') t3'))
 def setup(self):
     """Build an unsourced and a MockTable-backed CollectionExpr whose columns come from ``types._data_types``."""
     from odps.df.expr.tests.core import MockTable
     all_types = types._data_types
     schema = Schema.from_lists(all_types.keys(), all_types.values())
     self.expr = CollectionExpr(_source_data=None, _schema=schema)
     mock_source = MockTable(client=self.odps.rest)
     self.sourced_expr = CollectionExpr(_source_data=mock_source, _schema=schema)
Ejemplo n.º 41
0
class Test(TestBase):
    """UDF/UDTF code-generation tests plus one filter-pushdown compile test.

    The UDF tests compile a mapped/applied expression with the ODPS SQL engine,
    execute the generated Python source, and run the resulting class through
    ``runners.simple_run`` to check its output.
    """

    def setup(self):
        def datatypes(*type_names):
            return [validate_data_type(t) for t in type_names]

        schema = Schema.from_lists(['name', 'id', 'fid'],
                                   datatypes('string', 'int64', 'float64'))

        table = MockTable(name='pyodps_test_expr_table', schema=schema)

        self.expr = CollectionExpr(_source_data=table, _schema=schema)

        self.engine = ODPSEngine(self.odps)

    def _compile_udf(self, expr):
        """Compile *expr* and return the UDF/UDTF class from the generated source.

        Takes the first entry of the engine context's ``_func_to_udfs`` values,
        executes it, and returns the object bound to ``UDF_CLASS_NAME``.
        """
        self.engine.compile(expr)
        source = list(self.engine._ctx._func_to_udfs.values())[0]
        # Execute into one explicit dict (used as both globals and locals) rather
        # than reading the result back through locals(). Since Python 3.13
        # (PEP 667), each locals() call inside a function returns an independent
        # snapshot, so names defined by exec() were no longer visible afterwards.
        # A single namespace also lets functions in the generated code see its
        # own top-level names. Copying this module's globals keeps every name
        # that was visible before, without mutating the module itself.
        namespace = dict(globals())
        six.exec_(source, namespace)
        return namespace[UDF_CLASS_NAME]

    def testSimpleLambda(self):
        udf = self._compile_udf(self.expr.id.map(lambda x: x + 1))
        self.assertSequenceEqual([4, ], runners.simple_run(udf, [(3, ), ]))

    def testSimpleFunction(self):
        def my_func(x):
            if x < 0:
                return -1
            elif x == 0:
                return 0
            else:
                return 1

        udf = self._compile_udf(self.expr.id.map(my_func))
        self.assertSequenceEqual([-1, 0, 1], runners.simple_run(udf, [(-3, ), (0, ), (5, )]))

    def testNestFunction(self):
        def my_func(x):
            def inner(y):
                if y < 0:
                    return -2
                elif y == 0:
                    return 0
                else:
                    return 2
            return inner(x)

        udf = self._compile_udf(self.expr.id.map(my_func))
        self.assertSequenceEqual([-2, 0, 2], runners.simple_run(udf, [(-3, ), (0, ), (5, )]))

    def testGlobalVarFunction(self):
        global_val = 10
        def my_func(x):
            if x < global_val:
                return -1
            elif x == global_val:
                return 0
            else:
                return 1

        udf = self._compile_udf(self.expr.id.map(my_func))
        self.assertSequenceEqual([-1, 0, 1], runners.simple_run(udf, [(-9, ), (10, ), (15, )]))

    def testRefFuncFunction(self):
        global_val = 10
        def my_func1(x):
            if x < global_val:
                return -1
            elif x == global_val:
                return 0
            else:
                return 1
        def my_func(y):
            return my_func1(y)

        udf = self._compile_udf(self.expr.id.map(my_func))
        self.assertSequenceEqual([-1, 0, 1], runners.simple_run(udf, [(-9, ), (10, ), (15, )]))

    def testApplyToSequenceFuntion(self):
        def my_func(row):
            return row.name + str(row.id)

        udf = self._compile_udf(self.expr.apply(my_func, axis=1, reduce=True).rename('test'))
        self.assertEqual(['name1', 'name2'],
                         runners.simple_run(udf, [('name', 1, None), ('name', 2, None)]))

    def testApplyFunction(self):
        def my_func(row):
            return row.name, row.id

        udtf = self._compile_udf(
            self.expr.apply(my_func, axis=1, names=['name', 'id'], types=['string', 'int']))
        self.assertEqual([('name1', 1), ('name2', 2)],
                         runners.simple_run(udtf, [('name1', 1, None), ('name2', 2, None)]))

    def testApplyGeneratorFunction(self):
        def my_func(row):
            for n in row.name.split(','):
                yield n

        udtf = self._compile_udf(self.expr.apply(my_func, axis=1, names='name'))
        self.assertEqual(['name1', 'name2', 'name3', 'name4'],
                         runners.simple_run(udtf, [('name1,name2', 1, None), ('name3,name4', 2, None)]))

    @staticmethod
    def _cut_to_sql(col, bins, labels):
        """Render the CASE expression expected for ``cut(bins, labels=labels, include_over=True)``.

        Each adjacent pair ``(lo, hi)`` of *bins* becomes
        ``WHEN (lo < col) AND (col <= hi) THEN 'label'``. The last edge becomes
        the open-ended branch ``WHEN lo < col THEN 'last label'``.
        """
        clauses = ["WHEN (%s < %s) AND (%s <= %s) THEN '%s'" % (lo, col, col, hi, label)
                   for lo, hi, label in zip(bins, bins[1:], labels)]
        clauses.append("WHEN %s < %s THEN '%s'" % (bins[-1], col, labels[-1]))
        return 'CASE ' + ' '.join(clauses) + ' END'

    def testFilterPushdownThroughMultipleProjection(self):
        schema = Schema.from_lists(list('abcde'), ['string'] * 5)
        table = MockTable(name='pyodps_test_expr_table3', schema=schema)
        tab = CollectionExpr(_source_data=table, _schema=odps_schema_to_df_schema(schema))

        # `c` is cut with edges 0, 7, ..., 203 labelled '0-7', ..., '203-210';
        # the last label belongs to the open-ended "above 203" branch.
        c_bins = [7 * i for i in range(30)]
        c_labels = ['%d-%d' % (b, b + 7) for b in c_bins]

        p1 = tab.select(tab.a,
                        tab.c.astype('int').cut(c_bins, labels=c_labels, include_over=True).rename('c_cut'),
                        tab.e.astype('int').rename('e'),
                        tab.c.astype('int').rename('c'))
        p1['f'] = p1['e'] / p1['c']

        # `f` is cut with edges 0..19 labelled '0'..'19'.
        f_bins = list(range(20))
        f_labels = [str(b) for b in f_bins]
        p2 = p1.select(p1.a, p1.c_cut, p1.f.cut(bins=f_bins, labels=f_labels, include_over=True).rename('f_cut'))

        # The filter on f_cut is pushed through both projections, so the full
        # f_cut CASE expression is repeated in the WHERE clause.
        c_case = self._cut_to_sql('CAST(t1.`c` AS BIGINT)', c_bins, c_labels)
        f_case = self._cut_to_sql('(CAST(t1.`e` AS BIGINT) / CAST(t1.`c` AS BIGINT))', f_bins, f_labels)
        expected = ("SELECT t1.`a`, %s AS `c_cut`, %s AS `f_cut` \n"
                    "FROM mocked_project.`pyodps_test_expr_table3` t1 \n"
                    "WHERE (%s) == '9'") % (c_case, f_case, f_case)

        self.assertEqual(str(expected), str(ODPSEngine(self.odps).compile(p2[p2.f_cut == '9'], prettify=False)))
Ejemplo n.º 43
0
 def get_table2_df(self):
     """Return a ``CollectionExpr`` over a mocked table with string columns col21 and col22."""
     column_names = ['col21', 'col22']
     schema = Schema.from_lists(column_names, datatypes('string', 'string'))
     mock_table = MockTable(name=TEMP_TABLE_2_NAME, schema=schema)
     return CollectionExpr(_source_data=mock_table, _schema=schema)
class Test(TestBase):
    def setup(self):
        """Create mocked expressions ``expr``, ``expr1``, ``expr2`` and ``expr3``."""
        def datatypes(*type_names):
            return [validate_data_type(t) for t in type_names]

        def mock_expr(table_name, table_schema):
            # Expression backed by a MockTable; its df schema is rebuilt from
            # the table schema's ``columns``.
            mocked = MockTable(name=table_name, schema=table_schema)
            return CollectionExpr(_source_data=mocked, _schema=Schema(columns=table_schema.columns))

        schema = Schema.from_lists(['name', 'id', 'fid', 'isMale', 'scale', 'birth'],
                                   datatypes('string', 'int64', 'float64', 'boolean', 'decimal', 'datetime'),
                                   ['ds'], datatypes('string'))
        self.expr = mock_expr('pyodps_test_expr_table', schema)
        self.expr1 = mock_expr('pyodps_test_expr_table1', schema)
        self.expr2 = mock_expr('pyodps_test_expr_table2', schema)

        schema2 = Schema.from_lists(['name', 'id', 'fid'], datatypes('string', 'int64', 'float64'),
                                    ['part1', 'part2'], datatypes('string', 'int64'))
        # NOTE: expr3 deliberately reuses the table name of expr2; the expected
        # SQL in the tests below refers to it by that name.
        self.expr3 = mock_expr('pyodps_test_expr_table2', schema2)

    def testProjectPrune(self):
        """Column pruning and SQL compilation of simple projections, constants and builtins."""
        expr = self.expr.select('name', 'id')
        new_expr = ColumnPruning(expr.to_dag()).prune()
        self.assertIsInstance(new_expr, ProjectCollectionExpr)
        self.assertIsNotNone(new_expr.input._source_data)

        expected = 'SELECT t1.`name`, t1.`id` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1'
        # Normalise both sides with to_str, like every other comparison in this test.
        self.assertEqual(to_str(expected), to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))

        expr = self.expr[Scalar(3).rename('const'),
                         NullScalar('string').rename('string_const'),
                         self.expr.id]
        expected = 'SELECT 3 AS `const`, CAST(NULL AS STRING) AS `string_const`, t1.`id` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1'
        self.assertEqual(to_str(expected), to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))

        expr = self.expr.select(pt=BuiltinFunction('max_pt', args=(self.expr._source_data.name,)))
        expected = "SELECT max_pt('pyodps_test_expr_table') AS `pt` \n" \
                   "FROM mocked_project.`pyodps_test_expr_table` t1"
        self.assertEqual(to_str(expected), to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))

    def testApplyPrune(self):
        """Pruning an apply over a filter keeps the filter directly above the source table."""
        @output(['name', 'id'], ['string', 'string'])
        def h(row):
            yield row[0], row[1]

        filtered = self.expr[self.expr.fid < 0]
        expr = filtered.apply(h, axis=1)['id', ]
        pruned = ColumnPruning(expr.to_dag()).prune()

        self.assertIsInstance(pruned, ProjectCollectionExpr)
        filter_node = pruned.input.input
        self.assertIsInstance(filter_node, FilterCollectionExpr)
        self.assertIsNotNone(filter_node.input._source_data)

    def testFilterPrune(self):
        """Column pruning around filters, including a filter with an ``isin`` sub-query."""
        expr = self.expr.filter(self.expr.name == 'name1')
        expr = expr['name', 'id']

        new_expr = ColumnPruning(expr.to_dag()).prune()

        self.assertIsInstance(new_expr.input, FilterCollectionExpr)
        self.assertNotIsInstance(new_expr.input.input, ProjectCollectionExpr)
        self.assertIsNotNone(new_expr.input.input._source_data)

        expected = 'SELECT t1.`name`, t1.`id` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'WHERE t1.`name` == \'name1\''
        self.assertEqual(to_str(expected), to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))

        expr = self.expr.filter(self.expr.name == 'name1')

        new_expr = ColumnPruning(expr.to_dag()).prune()

        self.assertIsInstance(new_expr, FilterCollectionExpr)
        self.assertIsNotNone(new_expr.input._source_data)

        expr = self.expr.filter(self.expr.id.isin(self.expr3.id))

        expected = 'SELECT * \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'WHERE t1.`id` IN (SELECT t3.`id` FROM (  ' \
                   'SELECT t2.`id`   FROM mocked_project.`pyodps_test_expr_table2` t2 ) t3)'
        # assertTrue(a, b) only checks that `a` is truthy (`b` is just the failure
        # message), so this check could never fail before. Compare for real.
        self.assertEqual(to_str(expected), to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))

    def testFilterPartitionPrune(self):
        """Partition filter, then row filter, then projection: only name/id/fid survive pruning."""
        partitioned = self.expr.filter_partition('ds=today')
        filtered = partitioned[lambda x: x.fid < 0]
        expr = filtered['name', lambda x: x.id + 1]

        pruned = ColumnPruning(expr.to_dag()).prune()
        kept = set(pruned.input.input.schema.names)
        self.assertEqual(kept, set(['name', 'id', 'fid']))

        expected = ("SELECT t2.`name`, t2.`id` + 1 AS `id` \n"
                    "FROM (\n"
                    "  SELECT t1.`name`, t1.`id`, t1.`fid` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "  WHERE t1.`ds` == 'today' \n"
                    ") t2 \n"
                    "WHERE t2.`fid` < 0")
        compiled = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(compiled))

    def testSlicePrune(self):
        """Filter + slice + projection compiles to a single SELECT ... WHERE ... LIMIT."""
        limited = self.expr.filter(self.expr.fid < 0)[:4]
        expr = limited['name', lambda x: x.id + 1]

        pruned = ColumnPruning(expr.to_dag()).prune()
        self.assertIsNotNone(pruned.input.input.input._source_data)

        expected = ("SELECT t1.`name`, t1.`id` + 1 AS `id` \n"
                    "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "WHERE t1.`fid` < 0 \n"
                    "LIMIT 4")
        compiled = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(compiled))

    def testGroupbyPrune(self):
        """A HAVING filter on an aggregate still compiles when either output column is selected."""
        tail = ("FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                "GROUP BY t1.`name` \n"
                "HAVING MAX(t1.`id`) < 0")
        # (column kept after the filter, expected SELECT clause)
        cases = [
            ('name', "SELECT t1.`name` \n"),
            ('id', "SELECT MAX(t1.`id`) AS `id` \n"),
        ]
        for column, select_clause in cases:
            grouped = self.expr.groupby('name').agg(id=self.expr.id.max())
            expr = grouped[grouped.id < 0][column, ]
            expected = select_clause + tail
            compiled = ODPSEngine(self.odps).compile(expr, prettify=False)
            self.assertEqual(to_str(expected), to_str(compiled))

    def testMutatePrune(self):
        """Only the new_id column and its window sum remain after pruning a mutate."""
        with_new_id = self.expr[self.expr.exclude('birth'),
                                self.expr.fid.astype('int').rename('new_id')]
        cum_sum = with_new_id.groupby('name').mutate(
            lambda x: x.new_id.cumsum().rename('new_id_sum'))
        combined = with_new_id[with_new_id, cum_sum]
        expr = combined[combined.new_id, combined.new_id_sum]

        expected = ("SELECT t2.`new_id`, t2.`new_id_sum` \n"
                    "FROM (\n"
                    "  SELECT CAST(t1.`fid` AS BIGINT) AS `new_id`, "
                    "SUM(CAST(t1.`fid` AS BIGINT)) OVER (PARTITION BY t1.`name`) AS `new_id_sum` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    ") t2")

        compiled = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(compiled))

    def testValueCountsPrune(self):
        """value_counts only needs the counted column from its input."""
        expr = self.expr.name.value_counts()['count', ]
        pruned = ColumnPruning(expr.to_dag()).prune()

        projection = pruned.input.input
        self.assertIsInstance(projection, ProjectCollectionExpr)
        self.assertEqual(set(projection.schema.names), set(['name']))

    def testSortPrune(self):
        """The sort key column survives pruning inside the sub-query even if not selected outside."""
        renamed = self.expr[self.expr.exclude('name'), self.expr.name.rename('name2')]
        expr = renamed.sort('name2')['id', 'fid']

        expected = ("SELECT t2.`id`, t2.`fid` \n"
                    "FROM (\n"
                    "  SELECT t1.`id`, t1.`fid`, t1.`name` AS `name2` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "  ORDER BY name2 \n"
                    "  LIMIT 10000\n"
                    ") t2")
        compiled = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(compiled))

    def testDistinctPrune(self):
        """DISTINCT keeps all of its columns even when only one is selected on top."""
        distinct = self.expr.distinct(self.expr.id + 1, self.expr.name)
        expr = distinct['name', ]

        expected = ("SELECT t2.`name` \n"
                    "FROM (\n"
                    "  SELECT DISTINCT t1.`id` + 1 AS `id`, t1.`name` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    ") t2")
        compiled = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(compiled))

    def testSamplePrune(self):
        """Projection + sample + projection collapses into one SELECT with a SAMPLE filter."""
        sampled = self.expr['name', 'id'].sample(parts=5)
        expr = sampled['id', ]

        expected = ("SELECT t1.`id` \n"
                    "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "WHERE SAMPLE(5, 1)")
        compiled = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(compiled))

    def testJoinPrune(self):
        """Joins: only the join key is read when a single column is selected."""
        lhs = self.expr.select(self.expr, type='normal')
        rhs = self.expr3[:4]
        expr = lhs.left_join(rhs, on='id').id_x.rename('id')

        expected = ("SELECT t2.`id` \n"
                    "FROM (\n"
                    "  SELECT t1.`id` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1\n"
                    ") t2 \n"
                    "LEFT OUTER JOIN \n"
                    "  (\n"
                    "    SELECT t3.`id` \n"
                    "    FROM mocked_project.`pyodps_test_expr_table2` t3 \n"
                    "    LIMIT 4\n"
                    "  ) t4\n"
                    "ON t2.`id` == t4.`id`")

        compiled = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(compiled))

        # A bare inner join keeps every column of both sides.
        joined = self.expr.join(self.expr2, 'name')

        expected = ('SELECT t1.`name`, t1.`id` AS `id_x`, t1.`fid` AS `fid_x`, '
                    't1.`isMale` AS `isMale_x`, t1.`scale` AS `scale_x`, '
                    't1.`birth` AS `birth_x`, t1.`ds` AS `ds_x`, t2.`id` AS `id_y`, '
                    't2.`fid` AS `fid_y`, t2.`isMale` AS `isMale_y`, t2.`scale` AS `scale_y`, '
                    't2.`birth` AS `birth_y`, t2.`ds` AS `ds_y` \n'
                    'FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    'INNER JOIN \n'
                    '  mocked_project.`pyodps_test_expr_table2` t2\n'
                    'ON t1.`name` == t2.`name`')
        compiled = ODPSEngine(self.odps).compile(joined, prettify=False)
        self.assertEqual(to_str(expected), to_str(compiled))

    def testUnionPrune(self):
        left = self.expr.select('name', 'id')
        right = self.expr3.select(self.expr3.fid.astype('int').rename('id'), self.expr3.name)
        expr = left.union(right)['id']

        expected = "SELECT t3.`id` \n" \
                   "FROM (\n" \
                   "  SELECT t1.`id` \n" \
                   "  FROM mocked_project.`pyodps_test_expr_table` t1 \n" \
                   "  UNION ALL\n" \
                   "    SELECT CAST(t2.`fid` AS BIGINT) AS `id` \n" \
                   "    FROM mocked_project.`pyodps_test_expr_table2` t2\n" \
                   ") t3"
        self.assertEqual(to_str(expected), to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))

        expr = self.expr.union(self.expr2)

        expected = 'SELECT * \n' \
                   'FROM (\n' \
                   '  SELECT * \n' \
                   '  FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   '  UNION ALL\n' \
                   '    SELECT * \n' \
                   '    FROM mocked_project.`pyodps_test_expr_table2` t2\n' \
                   ') t3'

        self.assertEqual(to_str(expected), to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))