Ejemplo n.º 1
0
class Test(TestBase):
    def setup(self):
        # Fixtures: collection expressions over mocked tables, shared by the
        # pruning tests in this class.
        def datatypes(*type_names):
            # Validate each type name and collect the resulting type objects.
            return [validate_data_type(name) for name in type_names]

        def mocked_collection(table_name, table_schema):
            # Wrap a mocked table in a CollectionExpr whose schema is rebuilt
            # from the table schema's columns.
            source = MockTable(name=table_name, schema=table_schema)
            return CollectionExpr(_source_data=source,
                                  _schema=Schema(columns=table_schema.columns))

        column_names = ['name', 'id', 'fid', 'isMale', 'scale', 'birth']
        column_types = datatypes('string', 'int64', 'float64', 'boolean',
                                 'decimal', 'datetime')
        # The trailing name/type lists are presumably partition columns
        # (the tests filter on `ds`) -- TODO confirm against Schema.from_lists.
        schema = Schema.from_lists(column_names, column_types,
                                   ['ds'], datatypes('string'))

        self.expr = mocked_collection('pyodps_test_expr_table', schema)
        self.expr1 = mocked_collection('pyodps_test_expr_table1', schema)
        self.expr2 = mocked_collection('pyodps_test_expr_table2', schema)

        # Note: this table reuses the name of the table behind self.expr2.
        schema2 = Schema.from_lists(['name', 'id', 'fid'],
                                    datatypes('string', 'int64', 'float64'),
                                    ['part1', 'part2'],
                                    datatypes('string', 'int64'))
        self.expr3 = mocked_collection('pyodps_test_expr_table2', schema2)

        # self.expr4 pairs a table built on the first schema with an
        # expression schema that has dict- and list-typed columns.
        schema3 = Schema.from_lists(
            ['id', 'name', 'relatives', 'hobbies'],
            datatypes('int64', 'string', 'dict<string, string>',
                      'list<string>'))
        table4 = MockTable(name='pyodps_test_expr_table', schema=schema)
        self.expr4 = CollectionExpr(_source_data=table4, _schema=schema3)

    def testProjectPrune(self):
        def compiled(target):
            # Compile with a fresh engine each time, without prettifying.
            return ODPSEngine(self.odps).compile(target, prettify=False)

        # Selecting two columns straight from the source table: the pruned
        # tree must still be a projection sitting on a sourced collection.
        projected = self.expr.select('name', 'id')
        pruned = ColumnPruning(projected.to_dag()).prune()
        self.assertIsInstance(pruned, ProjectCollectionExpr)
        self.assertIsNotNone(pruned.input._source_data)

        sql = ('SELECT t1.`name`, t1.`id` \n'
               'FROM mocked_project.`pyodps_test_expr_table` t1')
        self.assertEqual(sql, compiled(projected))

        # Constant and typed-NULL scalars mixed with a real column.
        with_scalars = self.expr[Scalar(3).rename('const'),
                                 NullScalar('string').rename('string_const'),
                                 self.expr.id]
        sql = ('SELECT 3 AS `const`, CAST(NULL AS STRING) AS `string_const`, t1.`id` \n'
               'FROM mocked_project.`pyodps_test_expr_table` t1')
        self.assertEqual(to_str(sql), to_str(compiled(with_scalars)))

        # A builtin function call taking the source table name as argument.
        table_name = self.expr._source_data.name
        with_builtin = self.expr.select(
            pt=BuiltinFunction('max_pt', args=(table_name, )))
        sql = ("SELECT max_pt('pyodps_test_expr_table') AS `pt` \n"
               "FROM mocked_project.`pyodps_test_expr_table` t1")
        self.assertEqual(to_str(sql), to_str(compiled(with_builtin)))

    def testApplyPrune(self):
        # Row-wise function declared to emit two string columns.
        @output(['name', 'id'], ['string', 'string'])
        def h(row):
            yield row[0], row[1]

        # filter -> apply -> single-column projection
        filtered = self.expr[self.expr.fid < 0]
        expr = filtered.apply(h, axis=1)['id', ]
        pruned = ColumnPruning(expr.to_dag()).prune()

        # The filter must still sit two levels below the projection, directly
        # on top of the sourced collection.
        self.assertIsInstance(pruned, ProjectCollectionExpr)
        below_apply = pruned.input.input
        self.assertIsInstance(below_apply, FilterCollectionExpr)
        self.assertIsNotNone(below_apply.input._source_data)

    def testFilterPrune(self):
        # Case 1: filter followed by a projection. After pruning, the filter
        # must read straight from the sourced collection, with no extra
        # projection inserted in between.
        expr = self.expr.filter(self.expr.name == 'name1')
        expr = expr['name', 'id']

        new_expr = ColumnPruning(expr.to_dag()).prune()

        self.assertIsInstance(new_expr.input, FilterCollectionExpr)
        self.assertNotIsInstance(new_expr.input.input, ProjectCollectionExpr)
        self.assertIsNotNone(new_expr.input.input._source_data)

        expected = 'SELECT t1.`name`, t1.`id` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'WHERE t1.`name` == \'name1\''
        self.assertEqual(
            to_str(expected),
            to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))

        # Case 2: a bare filter stays a filter over the sourced collection.
        expr = self.expr.filter(self.expr.name == 'name1')

        new_expr = ColumnPruning(expr.to_dag()).prune()

        self.assertIsInstance(new_expr, FilterCollectionExpr)
        self.assertIsNotNone(new_expr.input._source_data)

        # Case 3: filter using isin() against a column of another collection.
        expr = self.expr.filter(self.expr.id.isin(self.expr3.id))

        expected = 'SELECT * \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'WHERE t1.`id` IN (SELECT t3.`id` FROM (  ' \
                   'SELECT t2.`id`   FROM mocked_project.`pyodps_test_expr_table2` t2 ) t3)'
        # assertEqual, not assertTrue: assertTrue(a, b) uses b as the failure
        # message and passes for any non-empty a, so the compiled SQL was
        # never actually compared against the expectation.
        self.assertEqual(
            to_str(expected),
            to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))

    def testFilterPartsPrune(self):
        # filter_parts -> row filter -> projection using `name` and `id`.
        in_partition = self.expr.filter_parts('ds=today')
        negative_fid = in_partition[lambda x: x.fid < 0]
        expr = negative_fid['name', lambda x: x.id + 1]

        # Two levels below the projection only the three columns that are
        # actually used may remain.
        pruned = ColumnPruning(expr.to_dag()).prune()
        remaining = set(pruned.input.input.schema.names)
        self.assertEqual(remaining, {'name', 'id', 'fid'})

        expected = ("SELECT t2.`name`, t2.`id` + 1 AS `id` \n"
                    "FROM (\n"
                    "  SELECT t1.`name`, t1.`id`, t1.`fid` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "  WHERE t1.`ds` == 'today' \n"
                    ") t2 \n"
                    "WHERE t2.`fid` < 0")
        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testSlicePrune(self):
        # filter -> first four rows -> projection
        filtered = self.expr.filter(self.expr.fid < 0)
        head = filtered[:4]
        expr = head['name', lambda x: x.id + 1]

        # The sourced collection must be reachable three levels down.
        pruned = ColumnPruning(expr.to_dag()).prune()
        self.assertIsNotNone(pruned.input.input.input._source_data)

        expected = ("SELECT t1.`name`, t1.`id` + 1 AS `id` \n"
                    "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "WHERE t1.`fid` < 0 \n"
                    "LIMIT 4")
        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testGroupbyPrune(self):
        # Aggregate, filter on the aggregate, then keep a single column. In
        # both cases the expected SQL expresses the filter as HAVING.
        cases = (
            ('name',
             "SELECT t1.`name` \n"
             "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
             "GROUP BY t1.`name` \n"
             "HAVING MAX(t1.`id`) < 0"),
            ('id',
             "SELECT MAX(t1.`id`) AS `id` \n"
             "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
             "GROUP BY t1.`name` \n"
             "HAVING MAX(t1.`id`) < 0"),
        )
        for kept_column, expected in cases:
            aggregated = self.expr.groupby('name').agg(id=self.expr.id.max())
            expr = aggregated[aggregated.id < 0][kept_column, ]
            self.assertEqual(
                to_str(expected),
                to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))

    def testMutatePrune(self):
        # Add a casted column, then a per-`name` cumulative sum of it, and
        # finally keep only those two derived columns.
        without_birth = self.expr.exclude('birth')
        new_id = self.expr.fid.astype('int').rename('new_id')
        base = self.expr[without_birth, new_id]

        running_sum = base.groupby('name').mutate(
            lambda x: x.new_id.cumsum().rename('new_id_sum'))
        widened = base[base, running_sum]
        expr = widened[widened.new_id, widened.new_id_sum]

        expected = ("SELECT t2.`new_id`, t2.`new_id_sum` \n"
                    "FROM (\n"
                    "  SELECT CAST(t1.`fid` AS BIGINT) AS `new_id`, "
                    "SUM(CAST(t1.`fid` AS BIGINT)) OVER (PARTITION BY t1.`name`) AS `new_id_sum` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    ") t2")

        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testValueCountsPrune(self):
        # Keep only the `count` column of a value_counts() result.
        counts = self.expr.name.value_counts()
        pruned = ColumnPruning(counts['count', ].to_dag()).prune()

        # Two levels down there must be a projection carrying only `name`.
        projection = pruned.input.input
        self.assertIsInstance(projection, ProjectCollectionExpr)
        self.assertEqual(set(projection.schema.names), {'name'})

    def testSortPrune(self):
        # Rename `name` to `name2`, sort by it, then keep columns that do not
        # include the sort key; the expected SQL still selects `name2` in the
        # inner query.
        others = self.expr.exclude('name')
        renamed = self.expr.name.rename('name2')
        ordered = self.expr[others, renamed].sort('name2')
        expr = ordered['id', 'fid']

        expected = ("SELECT t2.`id`, t2.`fid` \n"
                    "FROM (\n"
                    "  SELECT t1.`id`, t1.`fid`, t1.`name` AS `name2` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "  ORDER BY name2 \n"
                    "  LIMIT 10000\n"
                    ") t2")
        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testDistinctPrune(self):
        # Projecting one column out of a DISTINCT: the expected SQL keeps
        # both distinct columns in the inner query.
        deduplicated = self.expr.distinct(self.expr.id + 1, self.expr.name)
        expr = deduplicated['name', ]

        expected = ("SELECT t2.`name` \n"
                    "FROM (\n"
                    "  SELECT DISTINCT t1.`id` + 1 AS `id`, t1.`name` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    ") t2")
        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testSamplePrune(self):
        # projection -> sample(parts=5) -> narrower projection
        sampled = self.expr['name', 'id'].sample(parts=5)
        expr = sampled['id', ]

        expected = ("SELECT t1.`id` \n"
                    "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "WHERE SAMPLE(5, 1)")
        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testJoinPrune(self):
        def sql_of(target):
            # Fresh engine per compilation; result normalised with to_str.
            return to_str(
                ODPSEngine(self.odps).compile(target, prettify=False))

        # 1) A single column taken from a left join: the expected SQL reduces
        #    both sides to `id`.
        left = self.expr.select(self.expr, type='normal')
        right = self.expr3[:4]
        joined = left.left_join(right, on='id')
        expr = joined.id_x.rename('id')

        expected = ("SELECT t2.`id` \n"
                    "FROM (\n"
                    "  SELECT t1.`id` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1\n"
                    ") t2 \n"
                    "LEFT OUTER JOIN \n"
                    "  (\n"
                    "    SELECT t3.`id` \n"
                    "    FROM mocked_project.`pyodps_test_expr_table2` t3 \n"
                    "    LIMIT 4\n"
                    "  ) t4\n"
                    "ON t2.`id` == t4.`id`")
        self.assertEqual(to_str(expected), sql_of(expr))

        # 2) Inner join on a shared column name, compiled as a whole.
        joined = self.expr.join(self.expr2, 'name')

        expected = ('SELECT t1.`name`, t1.`id` AS `id_x`, t1.`fid` AS `fid_x`, '
                    't1.`isMale` AS `isMale_x`, t1.`scale` AS `scale_x`, '
                    't1.`birth` AS `birth_x`, t1.`ds` AS `ds_x`, t2.`id` AS `id_y`, '
                    't2.`fid` AS `fid_y`, t2.`isMale` AS `isMale_y`, t2.`scale` AS `scale_y`, '
                    't2.`birth` AS `birth_y`, t2.`ds` AS `ds_y` \n'
                    'FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    'INNER JOIN \n'
                    '  mocked_project.`pyodps_test_expr_table2` t2\n'
                    'ON t1.`name` == t2.`name`')
        self.assertEqual(to_str(expected), sql_of(joined))

        # 3) Two chained joins with explicit equality predicates.
        joined = self.expr.join(self.expr2,
                                on=[self.expr.name == self.expr2.name])
        joined2 = joined.join(self.expr, on=[joined.id_x == self.expr.id])

        expected = ('SELECT t1.`name` AS `name_x`, t1.`id` AS `id_x`, '
                    't1.`fid` AS `fid_x`, t1.`isMale` AS `isMale_x`, '
                    't1.`scale` AS `scale_x`, t1.`birth` AS `birth_x`, '
                    't1.`ds` AS `ds_x`, t2.`id` AS `id_y`, t2.`fid` AS `fid_y`, '
                    't2.`isMale` AS `isMale_y`, t2.`scale` AS `scale_y`, '
                    't2.`birth` AS `birth_y`, t2.`ds` AS `ds_y`, '
                    't3.`name` AS `name_y`, t3.`id`, t3.`fid`, t3.`isMale`, '
                    't3.`scale`, t3.`birth`, t3.`ds` \n'
                    'FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    'INNER JOIN \n'
                    '  mocked_project.`pyodps_test_expr_table2` t2\n'
                    'ON t1.`name` == t2.`name` \n'
                    'INNER JOIN \n'
                    '  mocked_project.`pyodps_test_expr_table` t3\n'
                    'ON t1.`id` == t3.`id`')
        self.assertEqual(to_str(expected), sql_of(joined2))

        # 4) The same chain with mapjoin=True: the expected SQL only gains
        #    the mapjoin hint.
        joined = self.expr.join(self.expr2,
                                on=[self.expr.name == self.expr2.name],
                                mapjoin=True)
        joined2 = joined.join(self.expr,
                              on=[joined.id_x == self.expr.id],
                              mapjoin=True)

        expected = ('SELECT /*+mapjoin(t2,t3)*/ t1.`name` AS `name_x`, t1.`id` AS `id_x`, '
                    't1.`fid` AS `fid_x`, t1.`isMale` AS `isMale_x`, t1.`scale` AS `scale_x`, '
                    't1.`birth` AS `birth_x`, t1.`ds` AS `ds_x`, t2.`id` AS `id_y`, '
                    't2.`fid` AS `fid_y`, t2.`isMale` AS `isMale_y`, t2.`scale` AS `scale_y`, '
                    't2.`birth` AS `birth_y`, t2.`ds` AS `ds_y`, t3.`name` AS `name_y`, '
                    't3.`id`, t3.`fid`, t3.`isMale`, t3.`scale`, t3.`birth`, t3.`ds` \n'
                    'FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    'INNER JOIN \n'
                    '  mocked_project.`pyodps_test_expr_table2` t2\n'
                    'ON t1.`name` == t2.`name` \n'
                    'INNER JOIN \n'
                    '  mocked_project.`pyodps_test_expr_table` t3\n'
                    'ON t1.`id` == t3.`id`')
        self.assertEqual(to_str(expected), sql_of(joined2))

    def testUnionPrune(self):
        def sql_of(target):
            # Fresh engine per compilation; result normalised with to_str.
            return to_str(
                ODPSEngine(self.odps).compile(target, prettify=False))

        # One column taken from a union: the expected SQL reduces both
        # branches to `id`.
        left = self.expr.select('name', 'id')
        right = self.expr3.select(
            self.expr3.fid.astype('int').rename('id'), self.expr3.name)
        expr = left.union(right)['id']

        expected = ("SELECT t3.`id` \n"
                    "FROM (\n"
                    "  SELECT t1.`id` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "  UNION ALL\n"
                    "    SELECT CAST(t2.`fid` AS BIGINT) AS `id` \n"
                    "    FROM mocked_project.`pyodps_test_expr_table2` t2\n"
                    ") t3")
        self.assertEqual(to_str(expected), sql_of(expr))

        # Union of two full collections with no projection on top.
        expr = self.expr.union(self.expr2)

        expected = ('SELECT * \n'
                    'FROM (\n'
                    '  SELECT * \n'
                    '  FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    '  UNION ALL\n'
                    '    SELECT * \n'
                    '    FROM mocked_project.`pyodps_test_expr_table2` t2\n'
                    ') t3')
        self.assertEqual(to_str(expected), sql_of(expr))

    def testLateralViewPrune(self):
        def compiled(target):
            # Fresh engine per compilation; compared without to_str here.
            return ODPSEngine(self.odps).compile(target, prettify=False)

        # Two plain columns plus an exploded list column.
        exploded = self.expr4['name', 'id', self.expr4.hobbies.explode()]
        pruned = ColumnPruning(exploded.to_dag()).prune()
        self.assertIsInstance(pruned, LateralViewCollectionExpr)
        self.assertIsNotNone(pruned.input._source_data)

        expected = ('SELECT t1.`name`, t1.`id`, t2.`hobbies` \n'
                    'FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    'LATERAL VIEW EXPLODE(t1.`hobbies`) t2 AS `hobbies`')
        self.assertEqual(expected, compiled(exploded))

        # Narrowing the result further drops `name` from the expected SQL.
        narrowed = exploded[exploded.id, exploded.hobbies]
        expected = ('SELECT t1.`id`, t2.`hobbies` \n'
                    'FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    'LATERAL VIEW EXPLODE(t1.`hobbies`) t2 AS `hobbies`')
        self.assertEqual(expected, compiled(narrowed))
class Test(TestBase):
    def setup(self):
        # Fixtures: collection expressions over mocked tables, shared by the
        # pruning tests in this class.
        def datatypes(*type_names):
            # Validate each type name and collect the resulting type objects.
            return [validate_data_type(name) for name in type_names]

        def mocked_collection(table_name, table_schema):
            # Wrap a mocked table in a CollectionExpr whose schema is rebuilt
            # from the table schema's columns.
            source = MockTable(name=table_name, schema=table_schema)
            return CollectionExpr(_source_data=source,
                                  _schema=Schema(columns=table_schema.columns))

        column_names = ['name', 'id', 'fid', 'isMale', 'scale', 'birth']
        column_types = datatypes('string', 'int64', 'float64', 'boolean',
                                 'decimal', 'datetime')
        # The trailing name/type lists are presumably partition columns
        # (the tests filter on `ds`) -- TODO confirm against Schema.from_lists.
        schema = Schema.from_lists(column_names, column_types,
                                   ['ds'], datatypes('string'))

        self.expr = mocked_collection('pyodps_test_expr_table', schema)
        self.expr1 = mocked_collection('pyodps_test_expr_table1', schema)
        self.expr2 = mocked_collection('pyodps_test_expr_table2', schema)

        # Note: this table reuses the name of the table behind self.expr2.
        schema2 = Schema.from_lists(['name', 'id', 'fid'],
                                    datatypes('string', 'int64', 'float64'),
                                    ['part1', 'part2'],
                                    datatypes('string', 'int64'))
        self.expr3 = mocked_collection('pyodps_test_expr_table2', schema2)

    def testProjectPrune(self):
        def compiled(target):
            # Compile with a fresh engine each time, without prettifying.
            return ODPSEngine(self.odps).compile(target, prettify=False)

        # Selecting two columns straight from the source table: the pruned
        # tree must still be a projection sitting on a sourced collection.
        projected = self.expr.select('name', 'id')
        pruned = ColumnPruning(projected.to_dag()).prune()
        self.assertIsInstance(pruned, ProjectCollectionExpr)
        self.assertIsNotNone(pruned.input._source_data)

        sql = ('SELECT t1.`name`, t1.`id` \n'
               'FROM mocked_project.`pyodps_test_expr_table` t1')
        self.assertEqual(sql, compiled(projected))

        # Constant and typed-NULL scalars mixed with a real column.
        with_scalars = self.expr[Scalar(3).rename('const'),
                                 NullScalar('string').rename('string_const'),
                                 self.expr.id]
        sql = ('SELECT 3 AS `const`, CAST(NULL AS STRING) AS `string_const`, t1.`id` \n'
               'FROM mocked_project.`pyodps_test_expr_table` t1')
        self.assertEqual(to_str(sql), to_str(compiled(with_scalars)))

        # A builtin function call taking the source table name as argument.
        table_name = self.expr._source_data.name
        with_builtin = self.expr.select(
            pt=BuiltinFunction('max_pt', args=(table_name,)))
        sql = ("SELECT max_pt('pyodps_test_expr_table') AS `pt` \n"
               "FROM mocked_project.`pyodps_test_expr_table` t1")
        self.assertEqual(to_str(sql), to_str(compiled(with_builtin)))

    def testApplyPrune(self):
        # Row-wise function declared to emit two string columns.
        @output(['name', 'id'], ['string', 'string'])
        def h(row):
            yield row[0], row[1]

        # filter -> apply -> single-column projection
        filtered = self.expr[self.expr.fid < 0]
        expr = filtered.apply(h, axis=1)['id', ]
        pruned = ColumnPruning(expr.to_dag()).prune()

        # The filter must still sit two levels below the projection, directly
        # on top of the sourced collection.
        self.assertIsInstance(pruned, ProjectCollectionExpr)
        below_apply = pruned.input.input
        self.assertIsInstance(below_apply, FilterCollectionExpr)
        self.assertIsNotNone(below_apply.input._source_data)

    def testFilterPrune(self):
        # Case 1: filter followed by a projection. After pruning, the filter
        # must read straight from the sourced collection, with no extra
        # projection inserted in between.
        expr = self.expr.filter(self.expr.name == 'name1')
        expr = expr['name', 'id']

        new_expr = ColumnPruning(expr.to_dag()).prune()

        self.assertIsInstance(new_expr.input, FilterCollectionExpr)
        self.assertNotIsInstance(new_expr.input.input, ProjectCollectionExpr)
        self.assertIsNotNone(new_expr.input.input._source_data)

        expected = 'SELECT t1.`name`, t1.`id` \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'WHERE t1.`name` == \'name1\''
        self.assertEqual(to_str(expected), to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))

        # Case 2: a bare filter stays a filter over the sourced collection.
        expr = self.expr.filter(self.expr.name == 'name1')

        new_expr = ColumnPruning(expr.to_dag()).prune()

        self.assertIsInstance(new_expr, FilterCollectionExpr)
        self.assertIsNotNone(new_expr.input._source_data)

        # Case 3: filter using isin() against a column of another collection.
        expr = self.expr.filter(self.expr.id.isin(self.expr3.id))

        expected = 'SELECT * \n' \
                   'FROM mocked_project.`pyodps_test_expr_table` t1 \n' \
                   'WHERE t1.`id` IN (SELECT t3.`id` FROM (  ' \
                   'SELECT t2.`id`   FROM mocked_project.`pyodps_test_expr_table2` t2 ) t3)'
        # assertEqual, not assertTrue: assertTrue(a, b) uses b as the failure
        # message and passes for any non-empty a, so the compiled SQL was
        # never actually compared against the expectation.
        self.assertEqual(to_str(expected), to_str(ODPSEngine(self.odps).compile(expr, prettify=False)))

    def testFilterPartitionPrune(self):
        # filter_partition -> row filter -> projection using `name` and `id`.
        in_partition = self.expr.filter_partition('ds=today')
        negative_fid = in_partition[lambda x: x.fid < 0]
        expr = negative_fid['name', lambda x: x.id + 1]

        # Two levels below the projection only the three columns that are
        # actually used may remain.
        pruned = ColumnPruning(expr.to_dag()).prune()
        remaining = set(pruned.input.input.schema.names)
        self.assertEqual(remaining, {'name', 'id', 'fid'})

        expected = ("SELECT t2.`name`, t2.`id` + 1 AS `id` \n"
                    "FROM (\n"
                    "  SELECT t1.`name`, t1.`id`, t1.`fid` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "  WHERE t1.`ds` == 'today' \n"
                    ") t2 \n"
                    "WHERE t2.`fid` < 0")
        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testSlicePrune(self):
        # filter -> first four rows -> projection
        filtered = self.expr.filter(self.expr.fid < 0)
        head = filtered[:4]
        expr = head['name', lambda x: x.id + 1]

        # The sourced collection must be reachable three levels down.
        pruned = ColumnPruning(expr.to_dag()).prune()
        self.assertIsNotNone(pruned.input.input.input._source_data)

        expected = ("SELECT t1.`name`, t1.`id` + 1 AS `id` \n"
                    "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "WHERE t1.`fid` < 0 \n"
                    "LIMIT 4")
        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testGroupbyPrune(self):
        # Aggregate, filter on the aggregate, then keep a single column. In
        # both cases the expected SQL expresses the filter as HAVING.
        cases = (
            ('name',
             "SELECT t1.`name` \n"
             "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
             "GROUP BY t1.`name` \n"
             "HAVING MAX(t1.`id`) < 0"),
            ('id',
             "SELECT MAX(t1.`id`) AS `id` \n"
             "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
             "GROUP BY t1.`name` \n"
             "HAVING MAX(t1.`id`) < 0"),
        )
        for kept_column, expected in cases:
            aggregated = self.expr.groupby('name').agg(id=self.expr.id.max())
            expr = aggregated[aggregated.id < 0][kept_column, ]
            actual = ODPSEngine(self.odps).compile(expr, prettify=False)
            self.assertEqual(to_str(expected), to_str(actual))

    def testMutatePrune(self):
        # Add a casted column, then a per-`name` cumulative sum of it, and
        # finally keep only those two derived columns.
        without_birth = self.expr.exclude('birth')
        new_id = self.expr.fid.astype('int').rename('new_id')
        base = self.expr[without_birth, new_id]

        running_sum = base.groupby('name').mutate(
            lambda x: x.new_id.cumsum().rename('new_id_sum'))
        widened = base[base, running_sum]
        expr = widened[widened.new_id, widened.new_id_sum]

        expected = ("SELECT t2.`new_id`, t2.`new_id_sum` \n"
                    "FROM (\n"
                    "  SELECT CAST(t1.`fid` AS BIGINT) AS `new_id`, "
                    "SUM(CAST(t1.`fid` AS BIGINT)) OVER (PARTITION BY t1.`name`) AS `new_id_sum` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    ") t2")

        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testValueCountsPrune(self):
        # Keep only the `count` column of a value_counts() result.
        counts = self.expr.name.value_counts()
        pruned = ColumnPruning(counts['count', ].to_dag()).prune()

        # Two levels down there must be a projection carrying only `name`.
        projection = pruned.input.input
        self.assertIsInstance(projection, ProjectCollectionExpr)
        self.assertEqual(set(projection.schema.names), {'name'})

    def testSortPrune(self):
        # Rename `name` to `name2`, sort by it, then keep columns that do not
        # include the sort key; the expected SQL still selects `name2` in the
        # inner query.
        others = self.expr.exclude('name')
        renamed = self.expr.name.rename('name2')
        ordered = self.expr[others, renamed].sort('name2')
        expr = ordered['id', 'fid']

        expected = ("SELECT t2.`id`, t2.`fid` \n"
                    "FROM (\n"
                    "  SELECT t1.`id`, t1.`fid`, t1.`name` AS `name2` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "  ORDER BY name2 \n"
                    "  LIMIT 10000\n"
                    ") t2")
        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testDistinctPrune(self):
        # Projecting one column out of a DISTINCT: the expected SQL keeps
        # both distinct columns in the inner query.
        deduplicated = self.expr.distinct(self.expr.id + 1, self.expr.name)
        expr = deduplicated['name', ]

        expected = ("SELECT t2.`name` \n"
                    "FROM (\n"
                    "  SELECT DISTINCT t1.`id` + 1 AS `id`, t1.`name` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    ") t2")
        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testSamplePrune(self):
        # projection -> sample(parts=5) -> narrower projection
        sampled = self.expr['name', 'id'].sample(parts=5)
        expr = sampled['id', ]

        expected = ("SELECT t1.`id` \n"
                    "FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "WHERE SAMPLE(5, 1)")
        actual = ODPSEngine(self.odps).compile(expr, prettify=False)
        self.assertEqual(to_str(expected), to_str(actual))

    def testJoinPrune(self):
        def sql_of(target):
            # Fresh engine per compilation; result normalised with to_str.
            return to_str(
                ODPSEngine(self.odps).compile(target, prettify=False))

        # 1) A single column taken from a left join: the expected SQL reduces
        #    both sides to `id`.
        left = self.expr.select(self.expr, type='normal')
        right = self.expr3[:4]
        joined = left.left_join(right, on='id')
        expr = joined.id_x.rename('id')

        expected = ("SELECT t2.`id` \n"
                    "FROM (\n"
                    "  SELECT t1.`id` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1\n"
                    ") t2 \n"
                    "LEFT OUTER JOIN \n"
                    "  (\n"
                    "    SELECT t3.`id` \n"
                    "    FROM mocked_project.`pyodps_test_expr_table2` t3 \n"
                    "    LIMIT 4\n"
                    "  ) t4\n"
                    "ON t2.`id` == t4.`id`")
        self.assertEqual(to_str(expected), sql_of(expr))

        # 2) Inner join on a shared column name, compiled as a whole.
        joined = self.expr.join(self.expr2, 'name')

        expected = ('SELECT t1.`name`, t1.`id` AS `id_x`, t1.`fid` AS `fid_x`, '
                    't1.`isMale` AS `isMale_x`, t1.`scale` AS `scale_x`, '
                    't1.`birth` AS `birth_x`, t1.`ds` AS `ds_x`, t2.`id` AS `id_y`, '
                    't2.`fid` AS `fid_y`, t2.`isMale` AS `isMale_y`, t2.`scale` AS `scale_y`, '
                    't2.`birth` AS `birth_y`, t2.`ds` AS `ds_y` \n'
                    'FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    'INNER JOIN \n'
                    '  mocked_project.`pyodps_test_expr_table2` t2\n'
                    'ON t1.`name` == t2.`name`')
        self.assertEqual(to_str(expected), sql_of(joined))

    def testUnionPrune(self):
        def sql_of(target):
            # Fresh engine per compilation; result normalised with to_str.
            return to_str(
                ODPSEngine(self.odps).compile(target, prettify=False))

        # One column taken from a union: the expected SQL reduces both
        # branches to `id`.
        left = self.expr.select('name', 'id')
        right = self.expr3.select(
            self.expr3.fid.astype('int').rename('id'), self.expr3.name)
        expr = left.union(right)['id']

        expected = ("SELECT t3.`id` \n"
                    "FROM (\n"
                    "  SELECT t1.`id` \n"
                    "  FROM mocked_project.`pyodps_test_expr_table` t1 \n"
                    "  UNION ALL\n"
                    "    SELECT CAST(t2.`fid` AS BIGINT) AS `id` \n"
                    "    FROM mocked_project.`pyodps_test_expr_table2` t2\n"
                    ") t3")
        self.assertEqual(to_str(expected), sql_of(expr))

        # Union of two full collections with no projection on top.
        expr = self.expr.union(self.expr2)

        expected = ('SELECT * \n'
                    'FROM (\n'
                    '  SELECT * \n'
                    '  FROM mocked_project.`pyodps_test_expr_table` t1 \n'
                    '  UNION ALL\n'
                    '    SELECT * \n'
                    '    FROM mocked_project.`pyodps_test_expr_table2` t2\n'
                    ') t3')
        self.assertEqual(to_str(expected), sql_of(expr))
Ejemplo n.º 3
0
class Test(TestBase):
    def setup(self):
        # Fixtures: two small two-column collections over mocked tables.
        def datatypes(*type_names):
            # Validate each type name and collect the resulting type objects.
            return [validate_data_type(name) for name in type_names]

        # Unlike a rebuilt schema, the same Schema object is shared between
        # each mocked table and its expression.
        schema = Schema.from_lists(['name', 'id'],
                                   datatypes('string', 'int64'))
        first_table = MockTable(name='pyodps_test_expr_table', schema=schema)
        self.expr = CollectionExpr(_source_data=first_table, _schema=schema)

        schema2 = Schema.from_lists(['name2', 'id2'],
                                    datatypes('string', 'int64'))
        second_table = MockTable(name='pyodps_test_expr_table2',
                                 schema=schema2)
        self.expr2 = CollectionExpr(_source_data=second_table,
                                    _schema=schema2)

    def _lines_eq(self, expected, actual):
        # Assert that two multi-line texts are equal line by line, ignoring
        # trailing whitespace on each line.
        def normalized(text):
            return [to_str(row.rstrip()) for row in text.split('\n')]

        expected_rows = normalized(expected)
        actual_rows = normalized(actual)
        self.assertSequenceEqual(expected_rows, actual_rows)

    def testProjectionFormatter(self):
        # repr of a casted column taken from a projection with a rename.
        projected = self.expr['name', self.expr.id.rename('new_id')]
        casted = projected.new_id.astype('float32')
        self._lines_eq(EXPECTED_PROJECTION_FORMAT, repr(casted))

    def testFilterFormatter(self):
        # repr of a filter whose predicate combines two comparisons.
        name_differs = self.expr.name != 'test'
        id_is_large = self.expr.id > 100
        filtered = self.expr[name_differs & id_is_large]
        self._lines_eq(EXPECTED_FILTER_FORMAT, repr(filtered))

    def testSliceFormatter(self):
        # repr of a stop-only slice, then of a slice with start and step.
        head = self.expr[:100]
        self._lines_eq(EXPECTED_SLICE_FORMAT, repr(head))

        stepped = self.expr[5:100:3]
        self._lines_eq(EXPECTED_SLICE_WITH_START_STEP_FORMAT, repr(stepped))

    def testArithmeticFormatter(self):
        # repr of a mixed arithmetic expression must match the expected text.
        # If the strict line comparison fails, lines that differ are allowed
        # to match approximately when both parse as floats; otherwise the
        # original assertion error is re-raised.
        expr = self.expr
        d = -(expr['id']) + 20.34 - expr['id'] + float(20) * expr['id'] \
            - expr['id'] / 4.9 + 40 // 2 + expr['id'] // 1.2

        try:
            self._lines_eq(EXPECTED_ARITHMETIC_FORMAT, repr(d))
        except AssertionError as e:
            left = [
                to_str(line.rstrip())
                for line in EXPECTED_ARITHMETIC_FORMAT.split('\n')
            ]
            right = [to_str(line.rstrip()) for line in repr(d).split('\n')]
            self.assertEqual(len(left), len(right))
            for expected_line, actual_line in zip(left, right):
                if expected_line == actual_line:
                    continue
                try:
                    self.assertAlmostEqual(float(expected_line),
                                           float(actual_line))
                except (ValueError, AssertionError):
                    # ValueError: the line is not a number at all;
                    # AssertionError: the numbers are not close enough.
                    # A bare `except:` here would also have swallowed
                    # KeyboardInterrupt / SystemExit.
                    raise e

    def testSortFormatter(self):
        # repr of a sort on a column name plus a negated column.
        sort_keys = ['name', -self.expr.id]
        ordered = self.expr.sort(sort_keys)
        self._lines_eq(EXPECTED_SORT_FORMAT, repr(ordered))

    def testDistinctFormatter(self):
        # repr of a distinct over a column name and a computed column.
        distinct_on = ['name', self.expr.id + 1]
        deduplicated = self.expr.distinct(distinct_on)
        self._lines_eq(EXPECTED_DISTINCT_FORMAT, repr(deduplicated))

    def testGroupbyFormatter(self):
        # repr of a group-by aggregation.
        by_name_and_id = self.expr.groupby(['name', 'id'])
        aggregated = by_name_and_id.agg(new_id=self.expr.id.sum())
        self._lines_eq(EXPECTED_GROUPBY_FORMAT, repr(aggregated))

        # repr of a group-by mutate with a row number sorted by `id`.
        grouped = self.expr.groupby(['name'])
        mutated = grouped.mutate(grouped.row_number(sort='id'))
        self._lines_eq(EXPECTED_MUTATE_FORMAT, repr(mutated))

        # repr of a group-by count.
        counted = self.expr.groupby(['name', 'id']).count()
        self._lines_eq(EXPECTED_GROUPBY_COUNT_FORMAT, repr(counted))

    def testReductionFormatter(self):
        # repr of a grouped std, a column mean and a collection count.
        grouped_std = self.expr.groupby(['id']).id.std()
        self._lines_eq(EXPECTED_REDUCTION_FORMAT, repr(grouped_std))

        column_mean = self.expr.id.mean()
        self._lines_eq(EXPECTED_REDUCTION_FORMAT2, repr(column_mean))

        row_count = self.expr.count()
        self._lines_eq(EXPECTED_REDUCTION_FORMAT3, repr(row_count))

    def testWindowFormatter(self):
        # repr of a rank over groups sorted by negated `id`.
        by_name = self.expr.groupby(['name'])
        ranked = by_name.sort(-self.expr.id).name.rank()
        self._lines_eq(EXPECTED_WINDOW_FORMAT1, repr(ranked))

        # repr of a cumulative mean with explicit window options.
        by_id = self.expr.groupby(['id'])
        running_mean = by_id.id.cummean(preceding=10, following=5,
                                        unique=True)
        self._lines_eq(EXPECTED_WINDOW_FORMAT2, repr(running_mean))

    def testElementFormatter(self):
        # repr of element-wise string, range, datetime and switch operations.
        contains_test = self.expr.name.contains('test')
        self._lines_eq(EXPECTED_STRING_FORMAT, repr(contains_test))

        in_range = self.expr.id.between(1, 3)
        self._lines_eq(EXPECTED_ELEMENT_FORMAT, repr(in_range))

        year_text = self.expr.name.astype('datetime').strftime('%Y')
        self._lines_eq(EXPECTED_DATETIME_FORMAT, repr(year_text))

        # Alternating case values and results, plus a default.
        switched = self.expr.id.switch(
            3, self.expr.name,
            4, self.expr.name + 'abc',
            default=self.expr.name + 'test')
        self._lines_eq(EXPECTED_SWITCH_FORMAT, repr(switched))

    def testJoinFormatter(self):
        # repr of a join given a pair of column names, one per side.
        join_columns = ('name', 'name2')
        joined = self.expr.join(self.expr2, join_columns)
        self._lines_eq(EXPECTED_JOIN_FORMAT, repr(joined))

    def testAstypeFormatter(self):
        # repr of a column cast to float.
        casted = self.expr.id.astype('float')
        self._lines_eq(EXPECTED_CAST_FORMAT, repr(casted))
class Test(TestBase):
    def setup(self):
        # Fixtures: two collections sharing a schema built from
        # types._data_types (keys passed as the names list, values as the
        # types list).
        from odps.df.expr.tests.core import MockTable

        type_registry = types._data_types
        schema = Schema.from_lists(type_registry.keys(), type_registry.values())

        # Collection with no source data.
        self.expr = CollectionExpr(_source_data=None, _schema=schema)

        # Collection whose source is a MockTable created with the test's
        # REST client.
        mock_table = MockTable(client=self.odps.rest)
        self.sourced_expr = CollectionExpr(_source_data=mock_table,
                                           _schema=schema)

    def testSort(self):
        # Every sorted result must be a CollectionExpr with the source schema
        # and the expected sequence of ascending flags.
        def check_sorted(result, expected_flags):
            self.assertIsInstance(result, CollectionExpr)
            self.assertEqual(result._schema, self.expr._schema)
            self.assertSequenceEqual(result._ascending, expected_flags)

        # Single column object as the key.
        check_sorted(self.expr.sort(self.expr.int64), [True])

        # sort_values() with a column object and a column name.
        check_sorted(self.expr.sort_values([self.expr.float32, 'string']),
                     [True] * 2)

        # A scalar `ascending` shows up once per key.
        check_sorted(
            self.expr.sort([self.expr.decimal, 'boolean', 'string'],
                           ascending=False),
            [False] * 3)

        # A list `ascending` is reflected key by key.
        check_sorted(
            self.expr.sort([self.expr.int8, 'datetime', 'float64'],
                           ascending=[False, True, False]),
            [False, True, False])

        # A negated column key yields a False flag for that key only.
        check_sorted(self.expr.sort([-self.expr.int8, 'datetime', 'float64']),
                     [False, True, True])

    def testDistinct(self):
        # distinct() results must be CollectionExprs whose schema equals the
        # schema given as the second argument of the checker.
        def check_distinct(result, expected_schema):
            self.assertIsInstance(result, CollectionExpr)
            self.assertEqual(result._schema, expected_schema)

        # No arguments: schema equals the source schema.
        check_distinct(self.expr.distinct(), self.expr._schema)

        # One column object: schema equals that of the matching projection.
        check_distinct(self.expr.distinct(self.expr.string),
                       self.expr[[self.expr.string]]._schema)

        # Column object plus column name.
        check_distinct(self.expr.distinct([self.expr.boolean, 'decimal']),
                       self.expr[['boolean', 'decimal']]._schema)

        # Projecting a unique() expression next to a plain column name must
        # raise ExpressionError.
        def project_with_unique():
            return self.expr['boolean', self.expr.string.unique()]

        self.assertRaises(ExpressionError, project_with_unique)

    def testMapReduce(self):
        # map_reduce() results are checked for the reducer's output column
        # names and for the ascending flag of every entry in _sort_fields.
        @output(['id', 'name', 'rating'], ['int', 'string', 'int'])
        def mapper(row):
            yield row.int64, row.string, row.int32

        @output(['name', 'rating'], ['string', 'int'])
        def reducer(_):
            i = [0]
            def h(row):
                if i[0] <= 1:
                    yield row.name, row.rating
            return h

        def check_result(result, expected_flags):
            self.assertEqual(result.schema.names, ['name', 'rating'])
            # bool() so only the truthiness of each flag is compared;
            # comparing whole lists also pins the number of sort fields.
            actual_flags = [bool(f._ascending) for f in result._sort_fields]
            self.assertEqual(actual_flags, expected_flags)

        # Scalar sort key with a scalar `ascending`.
        check_result(
            self.expr.map_reduce(mapper, reducer, group='name',
                                 sort='rating', ascending=False),
            [True, False])

        # Two sort keys with one `ascending` entry per key.
        check_result(
            self.expr.map_reduce(mapper, reducer, group='name',
                                 sort=['rating', 'id'],
                                 ascending=[False, True]),
            [True, False, True])

        # Two sort keys sharing a scalar `ascending`.
        check_result(
            self.expr.map_reduce(mapper, reducer, group='name',
                                 sort=['rating', 'id'], ascending=False),
            [True, False, False])

        # The group column is also listed among the sort keys.
        check_result(
            self.expr.map_reduce(mapper, reducer, group='name',
                                 sort=['name', 'rating', 'id'],
                                 ascending=[False, True, False]),
            [False, True, False])

    def testSample(self):
        # Accepted sample() calls are checked for their result type; the
        # argument combinations listed at the end must raise ExpressionError.
        self.assertIsInstance(self.expr.sample(100), SampledCollectionExpr)
        self.assertIsInstance(self.expr.sample(parts=10),
                              SampledCollectionExpr)

        try:
            import pandas
        except ImportError:
            # No pandas: go for XFlow
            frac_type = DataFrame
        else:
            # Otherwise: go for Pandas
            frac_type = SampledCollectionExpr
        self.assertIsInstance(self.expr.sample(frac=0.5), frac_type)
        self.assertIsInstance(self.sourced_expr.sample(frac=0.5), DataFrame)

        def assert_rejected(**kwargs):
            self.assertRaises(ExpressionError,
                              lambda: self.expr.sample(**kwargs))

        assert_rejected()
        assert_rejected(i=-1)
        assert_rejected(n=100, frac=0.5)
        assert_rejected(n=100, parts=10)
        assert_rejected(frac=0.5, parts=10)
        assert_rejected(n=100, frac=0.5, parts=10)
        assert_rejected(frac=-1)
        assert_rejected(frac=1.5)
        assert_rejected(parts=10, i=-1)
        assert_rejected(parts=10, i=10)
        assert_rejected(parts=10, n=10)
        assert_rejected(weights='weights', strata='strata')
        assert_rejected(frac='Yes:10', strata='strata')
        assert_rejected(frac=set(), strata='strata')
        assert_rejected(n=set(), strata='strata')

    def testPivot(self):
        from odps.df.expr.dynamic import DynamicMixin

        def check_pivoted(result, index_names):
            # _name_indexes must hold exactly the given names ...
            name_indexes = result._schema._name_indexes
            for name in index_names:
                self.assertIn(name, name_indexes)
            self.assertEqual(len(name_indexes), len(index_names))
            # ... while a column name that was never declared is still
            # reported as present and comes back as a DynamicMixin.
            self.assertIn('non_exist', result._schema)
            self.assertIsInstance(result['non_exist'], DynamicMixin)

        check_pivoted(self.expr.pivot('string', 'int8', 'float32'),
                      ['string'])

        check_pivoted(
            self.expr.pivot(['string', 'int8'], 'int16',
                            ['datetime', 'string']),
            ['string', 'int8'])

        # A list as the second positional argument must raise ValueError.
        def pivot_with_list_second_arg():
            return self.expr.pivot(
                ['string', 'int8'], ['datetime', 'string'], 'int16')

        self.assertRaises(ValueError, pivot_with_list_second_arg)

    def testPivotTable(self):
        from odps.df.expr.dynamic import DynamicMixin

        # No aggfunc given: the value column appears as `int8_mean`.
        single_value = self.expr.pivot_table(values='int8', rows='float32')
        self.assertNotIsInstance(single_value, DynamicMixin)
        self.assertEqual(single_value.schema.names, ['float32', 'int8_mean'])

        # Several values and rows.
        multi_value = self.expr.pivot_table(values=('int16', 'int32'),
                                            rows=['float32', 'int8'])
        self.assertEqual(multi_value.schema.names,
                         ['float32', 'int8', 'int16_mean', 'int32_mean'])

        # Two named aggregations: both names and types are checked.
        two_funcs = self.expr.pivot_table(values=('int16', 'int32'),
                                          rows=['string', 'boolean'],
                                          aggfunc=['mean', 'sum'])
        expected_names = ['string', 'boolean', 'int16_mean', 'int32_mean',
                          'int16_sum', 'int32_sum']
        expected_types = [types.string, types.boolean, types.float64,
                          types.float64, types.int16, types.int32]
        self.assertEqual(two_funcs.schema.names, expected_names)
        self.assertEqual(two_funcs.schema.types, expected_types)

        # User-defined aggregator declared with output name `my_mean`.
        @output(['my_mean'], ['float'])
        class Aggregator(object):
            def buffer(self):
                return [0.0, 0]

            def __call__(self, buffer, val):
                buffer[0] += val
                buffer[1] += 1

            def merge(self, buffer, pbuffer):
                buffer[0] += pbuffer[0]
                buffer[1] += pbuffer[1]

            def getvalue(self, buffer):
                if buffer[1] == 0:
                    return 0.0
                return buffer[0] / buffer[1]

        custom = self.expr.pivot_table(values='int16', rows='string',
                                       aggfunc=Aggregator)
        self.assertEqual(custom.schema.names, ['string', 'int16_my_mean'])
        self.assertEqual(custom.schema.types, [types.string, types.float64])

        # With an ordered mapping the keys show up in the column names.
        aggfunc = OrderedDict([('my_agg', Aggregator),
                               ('my_agg2', Aggregator)])
        keyed = self.expr.pivot_table(values='int16', rows='string',
                                      aggfunc=aggfunc)
        self.assertEqual(keyed.schema.names,
                         ['string', 'int16_my_agg', 'int16_my_agg2'])
        self.assertEqual(keyed.schema.types,
                         [types.string, types.float64, types.float64])

        # Passing `columns` gives a DynamicMixin result.
        with_columns = self.expr.pivot_table(values='int16',
                                             columns='boolean',
                                             rows='string')
        self.assertIsInstance(with_columns, DynamicMixin)

    def testScaleValues(self):
        def check_names(result, expected_names):
            self.assertIsInstance(result, CollectionExpr)
            self.assertListEqual(result.dtypes.names, expected_names)

        source_names = self.expr.dtypes.names
        # Expected with preserve=True: the source names followed by a
        # `_scaled` name for every column whose name starts with 'int' or
        # 'float'.
        preserved_names = source_names + [
            name + '_scaled' for name in source_names
            if name.startswith(('int', 'float'))
        ]

        check_names(self.expr.min_max_scale(), source_names)
        check_names(self.expr.min_max_scale(preserve=True), preserved_names)

        check_names(self.expr.std_scale(), source_names)
        check_names(self.expr.std_scale(preserve=True), preserved_names)

    def testAppendId(self):
        # append_id() on the source-less collection: the expected result type
        # depends on whether pandas can be imported.
        appended = self.expr.append_id(id_col='id_col')
        try:
            import pandas
        except ImportError:
            # No pandas: go for XFlow
            expected_type = DataFrame
        else:
            # Otherwise: go for Pandas
            expected_type = AppendIDCollectionExpr
        self.assertIsInstance(appended, expected_type)
        self.assertIn('id_col', appended.schema)

        # On the sourced collection the result is checked against DataFrame.
        sourced_result = self.sourced_expr.append_id()
        self.assertIsInstance(sourced_result, DataFrame)

    def testSplit(self):
        # split(0.6) is unpacked into two parts; their expected type depends
        # on whether pandas can be imported.
        first, second = self.expr.split(0.6)
        try:
            import pandas
        except ImportError:
            # No pandas: go for XFlow
            pandas_available = False
        else:
            # Otherwise: go for Pandas
            pandas_available = True

        part_type = SplitCollectionExpr if pandas_available else DataFrame
        self.assertIsInstance(first, part_type)
        self.assertIsInstance(second, part_type)
        if pandas_available:
            # The parts must carry split ids 0 and 1, in that order.
            self.assertTupleEqual((first._split_id, second._split_id), (0, 1))
Ejemplo n.º 5
0
class Test(TestBase):
    def setup(self):
        from odps.df.expr.tests.core import MockTable

        # Schema built from types._data_types: keys are passed as the names
        # list and values as the types list.
        column_names = types._data_types.keys()
        column_types = types._data_types.values()
        schema = Schema.from_lists(column_names, column_types)

        # Collection with no source data.
        self.expr = CollectionExpr(_source_data=None, _schema=schema)

        # Collection whose source is a MockTable created with the test's
        # REST client.
        mock_table = MockTable(client=self.odps.rest)
        self.sourced_expr = CollectionExpr(_source_data=mock_table,
                                           _schema=schema)

    def testSort(self):
        # Every sorted result must be a CollectionExpr with the source schema
        # and the expected sequence of ascending flags.
        def expect_sorted(result, expected_flags):
            self.assertIsInstance(result, CollectionExpr)
            self.assertEqual(result._schema, self.expr._schema)
            self.assertSequenceEqual(result._ascending, expected_flags)

        # Single column object as the key.
        expect_sorted(self.expr.sort(self.expr.int64), [True])

        # sort_values() with a column object and a column name.
        expect_sorted(self.expr.sort_values([self.expr.float32, 'string']),
                      [True] * 2)

        # A scalar `ascending` shows up once per key.
        expect_sorted(
            self.expr.sort([self.expr.decimal, 'boolean', 'string'],
                           ascending=False),
            [False] * 3)

        # A list `ascending` is reflected key by key.
        expect_sorted(
            self.expr.sort([self.expr.int8, 'datetime', 'float64'],
                           ascending=[False, True, False]),
            [False, True, False])

        # A negated column key yields a False flag for that key only.
        expect_sorted(
            self.expr.sort([-self.expr.int8, 'datetime', 'float64']),
            [False, True, True])

    def testDistinct(self):
        # distinct() results must be CollectionExprs whose schema equals the
        # schema given as the second argument of the checker.
        def expect_distinct(result, expected_schema):
            self.assertIsInstance(result, CollectionExpr)
            self.assertEqual(result._schema, expected_schema)

        # No arguments: schema equals the source schema.
        expect_distinct(self.expr.distinct(), self.expr._schema)

        # One column object: schema equals that of the matching projection.
        expect_distinct(self.expr.distinct(self.expr.string),
                        self.expr[[self.expr.string]]._schema)

        # Column object plus column name.
        expect_distinct(self.expr.distinct([self.expr.boolean, 'decimal']),
                        self.expr[['boolean', 'decimal']]._schema)

        # Projecting a unique() expression next to a plain column name must
        # raise ExpressionError.
        def project_with_unique():
            return self.expr['boolean', self.expr.string.unique()]

        self.assertRaises(ExpressionError, project_with_unique)

    def testMapReduce(self):
        # map_reduce() results are checked for the reducer's output column
        # names and for the ascending flag of every entry in _sort_fields.
        @output(['id', 'name', 'rating'], ['int', 'string', 'int'])
        def mapper(row):
            yield row.int64, row.string, row.int32

        @output(['name', 'rating'], ['string', 'int'])
        def reducer(_):
            i = [0]

            def h(row):
                if i[0] <= 1:
                    yield row.name, row.rating

            return h

        def expect_result(result, expected_flags):
            self.assertEqual(result.schema.names, ['name', 'rating'])
            # bool() so only the truthiness of each flag is compared;
            # comparing whole lists also pins the number of sort fields.
            actual_flags = [bool(f._ascending) for f in result._sort_fields]
            self.assertEqual(actual_flags, expected_flags)

        # Scalar sort key with a scalar `ascending`.
        single_key = self.expr.map_reduce(
            mapper, reducer, group='name', sort='rating', ascending=False)
        expect_result(single_key, [True, False])

        # Two sort keys with one `ascending` entry per key.
        per_key = self.expr.map_reduce(
            mapper, reducer, group='name', sort=['rating', 'id'],
            ascending=[False, True])
        expect_result(per_key, [True, False, True])

        # Two sort keys sharing a scalar `ascending`.
        shared_flag = self.expr.map_reduce(
            mapper, reducer, group='name', sort=['rating', 'id'],
            ascending=False)
        expect_result(shared_flag, [True, False, False])

    def testSample(self):
        # Accepted sample() calls are checked for their result type; the
        # argument combinations listed at the end must raise ExpressionError.
        by_count = self.expr.sample(100)
        self.assertIsInstance(by_count, SampledCollectionExpr)
        by_parts = self.expr.sample(parts=10)
        self.assertIsInstance(by_parts, SampledCollectionExpr)

        try:
            import pandas
        except ImportError:
            # No pandas: go for XFlow
            frac_type = DataFrame
        else:
            # Otherwise: go for Pandas
            frac_type = SampledCollectionExpr
        self.assertIsInstance(self.expr.sample(frac=0.5), frac_type)
        self.assertIsInstance(self.sourced_expr.sample(frac=0.5), DataFrame)

        def assert_rejected(**kwargs):
            self.assertRaises(ExpressionError,
                              lambda: self.expr.sample(**kwargs))

        assert_rejected()
        assert_rejected(i=-1)
        assert_rejected(n=100, frac=0.5)
        assert_rejected(n=100, parts=10)
        assert_rejected(frac=0.5, parts=10)
        assert_rejected(n=100, frac=0.5, parts=10)
        assert_rejected(frac=-1)
        assert_rejected(frac=1.5)
        assert_rejected(parts=10, i=-1)
        assert_rejected(parts=10, i=10)
        assert_rejected(parts=10, n=10)
        assert_rejected(weights='weights', strata='strata')
        assert_rejected(frac='Yes:10', strata='strata')
        assert_rejected(frac=set(), strata='strata')
        assert_rejected(n=set(), strata='strata')
Ejemplo n.º 6
0
class Test(TestBase):
    def setup(self):
        from odps.df.expr.tests.core import MockTable

        # Schema built from types._data_types: keys are passed as the names
        # list and values as the types list.
        data_types = types._data_types
        schema = Schema.from_lists(data_types.keys(), data_types.values())

        # Collection with no source data.
        self.expr = CollectionExpr(_source_data=None, _schema=schema)

        # Collection whose source is a MockTable created with the test's
        # REST client.
        table = MockTable(client=self.odps.rest)
        self.sourced_expr = CollectionExpr(_source_data=table, _schema=schema)

    def testSort(self):
        # Every sorted result must be a CollectionExpr with the source schema
        # and the expected sequence of ascending flags.
        def verify(result, expected_flags):
            self.assertIsInstance(result, CollectionExpr)
            self.assertEqual(result._schema, self.expr._schema)
            self.assertSequenceEqual(result._ascending, expected_flags)

        # Single column object as the key.
        by_column = self.expr.sort(self.expr.int64)
        verify(by_column, [True])

        # sort_values() with a column object and a column name.
        by_values = self.expr.sort_values([self.expr.float32, 'string'])
        verify(by_values, [True] * 2)

        # A scalar `ascending` shows up once per key.
        all_descending = self.expr.sort(
            [self.expr.decimal, 'boolean', 'string'], ascending=False)
        verify(all_descending, [False] * 3)

        # A list `ascending` is reflected key by key.
        mixed = self.expr.sort([self.expr.int8, 'datetime', 'float64'],
                               ascending=[False, True, False])
        verify(mixed, [False, True, False])

        # A negated column key yields a False flag for that key only.
        negated = self.expr.sort([-self.expr.int8, 'datetime', 'float64'])
        verify(negated, [False, True, True])

    def testDistinct(self):
        # distinct() results must be CollectionExprs whose schema equals the
        # schema given as the second argument of the checker.
        def verify(result, expected_schema):
            self.assertIsInstance(result, CollectionExpr)
            self.assertEqual(result._schema, expected_schema)

        # No arguments: schema equals the source schema.
        verify(self.expr.distinct(), self.expr._schema)

        # One column object: schema equals that of the matching projection.
        verify(self.expr.distinct(self.expr.string),
               self.expr[[self.expr.string]]._schema)

        # Column object plus column name.
        verify(self.expr.distinct([self.expr.boolean, 'decimal']),
               self.expr[['boolean', 'decimal']]._schema)

        # Projecting a unique() expression next to a plain column name must
        # raise ExpressionError.
        def project_with_unique():
            return self.expr['boolean', self.expr.string.unique()]

        self.assertRaises(ExpressionError, project_with_unique)

    def testMapReduce(self):
        # map_reduce() results are checked for the reducer's output column
        # names and for the ascending flag of every entry in
        # result.input._sort_fields.
        @output(['id', 'name', 'rating'], ['int', 'string', 'int'])
        def mapper(row):
            yield row.int64, row.string, row.int32

        @output(['name', 'rating'], ['string', 'int'])
        def reducer(_):
            i = [0]

            def h(row):
                if i[0] <= 1:
                    yield row.name, row.rating

            return h

        def verify(result, expected_flags):
            self.assertEqual(result.schema.names, ['name', 'rating'])
            # bool() so only the truthiness of each flag is compared;
            # comparing whole lists also pins the number of sort fields.
            sort_fields = result.input._sort_fields
            actual_flags = [bool(f._ascending) for f in sort_fields]
            self.assertEqual(actual_flags, expected_flags)

        # Scalar sort key with a scalar `ascending`.
        single_key = self.expr.map_reduce(
            mapper, reducer, group='name', sort='rating', ascending=False)
        verify(single_key, [True, False])

        # Two sort keys with one `ascending` entry per key.
        per_key = self.expr.map_reduce(
            mapper, reducer, group='name', sort=['rating', 'id'],
            ascending=[False, True])
        verify(per_key, [True, False, True])

        # Two sort keys sharing a scalar `ascending`.
        shared_flag = self.expr.map_reduce(
            mapper, reducer, group='name', sort=['rating', 'id'],
            ascending=False)
        verify(shared_flag, [True, False, False])

        # The group column is also listed among the sort keys.
        group_in_sort = self.expr.map_reduce(
            mapper, reducer, group='name', sort=['name', 'rating', 'id'],
            ascending=[False, True, False])
        verify(group_in_sort, [False, True, False])

    def testSample(self):
        from odps.ml.expr import AlgoCollectionExpr

        # Accepted sample() calls are checked for their result type; the
        # argument combinations listed at the end must raise ExpressionError.
        by_count = self.expr.sample(100)
        self.assertIsInstance(by_count, SampledCollectionExpr)
        by_parts = self.expr.sample(parts=10)
        self.assertIsInstance(by_parts, SampledCollectionExpr)

        try:
            import pandas
        except ImportError:
            # No pandas: go for XFlow
            frac_type = AlgoCollectionExpr
        else:
            # Otherwise: go for Pandas
            frac_type = SampledCollectionExpr
        self.assertIsInstance(self.expr.sample(frac=0.5), frac_type)
        self.assertIsInstance(self.sourced_expr.sample(frac=0.5),
                              AlgoCollectionExpr)

        def assert_rejected(**kwargs):
            self.assertRaises(ExpressionError,
                              lambda: self.expr.sample(**kwargs))

        assert_rejected()
        assert_rejected(i=-1)
        assert_rejected(n=100, frac=0.5)
        assert_rejected(n=100, parts=10)
        assert_rejected(frac=0.5, parts=10)
        assert_rejected(n=100, frac=0.5, parts=10)
        assert_rejected(frac=-1)
        assert_rejected(frac=1.5)
        assert_rejected(parts=10, i=-1)
        assert_rejected(parts=10, i=10)
        assert_rejected(parts=10, n=10)
        assert_rejected(weights='weights', strata='strata')
        assert_rejected(frac='Yes:10', strata='strata')
        assert_rejected(frac=set(), strata='strata')
        assert_rejected(n=set(), strata='strata')

    def testPivot(self):
        from odps.df.expr.dynamic import DynamicMixin

        def verify(result, index_names):
            # _name_indexes must hold exactly the given names ...
            name_indexes = result._schema._name_indexes
            for name in index_names:
                self.assertIn(name, name_indexes)
            self.assertEqual(len(name_indexes), len(index_names))
            # ... while a column name that was never declared is still
            # reported as present and comes back as a DynamicMixin.
            self.assertIn('non_exist', result._schema)
            self.assertIsInstance(result['non_exist'], DynamicMixin)

        single_row = self.expr.pivot('string', 'int8', 'float32')
        verify(single_row, ['string'])

        two_rows = self.expr.pivot(['string', 'int8'], 'int16',
                                   ['datetime', 'string'])
        verify(two_rows, ['string', 'int8'])

        # A list as the second positional argument must raise ValueError.
        def pivot_with_list_second_arg():
            return self.expr.pivot(
                ['string', 'int8'], ['datetime', 'string'], 'int16')

        self.assertRaises(ValueError, pivot_with_list_second_arg)

    def testPivotTable(self):
        from odps.df.expr.dynamic import DynamicMixin

        # No aggfunc given: the value column appears as `int8_mean`.
        single_value = self.expr.pivot_table(values='int8', rows='float32')
        self.assertNotIsInstance(single_value, DynamicMixin)
        self.assertEqual(single_value.schema.names, ['float32', 'int8_mean'])

        # Several values and rows.
        multi_value = self.expr.pivot_table(values=('int16', 'int32'),
                                            rows=['float32', 'int8'])
        self.assertEqual(multi_value.schema.names,
                         ['float32', 'int8', 'int16_mean', 'int32_mean'])

        # Two named aggregations: both names and types are checked.
        two_funcs = self.expr.pivot_table(values=('int16', 'int32'),
                                          rows=['string', 'boolean'],
                                          aggfunc=['mean', 'sum'])
        expected_names = ['string', 'boolean', 'int16_mean', 'int32_mean',
                          'int16_sum', 'int32_sum']
        expected_types = [types.string, types.boolean, types.float64,
                          types.float64, types.int16, types.int32]
        self.assertEqual(two_funcs.schema.names, expected_names)
        self.assertEqual(two_funcs.schema.types, expected_types)

        # User-defined aggregator declared with output name `my_mean`.
        @output(['my_mean'], ['float'])
        class Aggregator(object):
            def buffer(self):
                return [0.0, 0]

            def __call__(self, buffer, val):
                buffer[0] += val
                buffer[1] += 1

            def merge(self, buffer, pbuffer):
                buffer[0] += pbuffer[0]
                buffer[1] += pbuffer[1]

            def getvalue(self, buffer):
                if buffer[1] == 0:
                    return 0.0
                return buffer[0] / buffer[1]

        custom = self.expr.pivot_table(values='int16', rows='string',
                                       aggfunc=Aggregator)
        self.assertEqual(custom.schema.names, ['string', 'int16_my_mean'])
        self.assertEqual(custom.schema.types, [types.string, types.float64])

        # With an ordered mapping the keys show up in the column names.
        aggfunc = OrderedDict([('my_agg', Aggregator),
                               ('my_agg2', Aggregator)])
        keyed = self.expr.pivot_table(values='int16', rows='string',
                                      aggfunc=aggfunc)
        self.assertEqual(keyed.schema.names,
                         ['string', 'int16_my_agg', 'int16_my_agg2'])
        self.assertEqual(keyed.schema.types,
                         [types.string, types.float64, types.float64])

        # Passing `columns` gives a DynamicMixin result.
        with_columns = self.expr.pivot_table(values='int16',
                                             columns='boolean',
                                             rows='string')
        self.assertIsInstance(with_columns, DynamicMixin)

    def testScaleValue(self):
        def verify(result, expected_names):
            self.assertIsInstance(result, CollectionExpr)
            self.assertListEqual(result.dtypes.names, expected_names)

        source_names = self.expr.dtypes.names
        # Expected with preserve=True: the source names followed by a
        # `_scaled` name for every column whose name starts with 'int' or
        # 'float'.
        scaled_suffix_names = [
            name + '_scaled' for name in source_names
            if name.startswith(('int', 'float'))
        ]
        preserved_names = source_names + scaled_suffix_names

        verify(self.expr.min_max_scale(), source_names)
        verify(self.expr.min_max_scale(preserve=True), preserved_names)

        verify(self.expr.std_scale(), source_names)
        verify(self.expr.std_scale(preserve=True), preserved_names)

    def testApplyMap(self):
        from odps.df.expr.collections import ProjectCollectionExpr, Column
        from odps.df.expr.element import MappedExpr

        field_names = ['idx', 'f1', 'f2', 'f3']
        field_types = [types.int64] + [types.float64] * 3
        schema = Schema.from_lists(field_names, field_types)
        expr = CollectionExpr(_source_data=None, _schema=schema)

        # `columns` and `excludes` passed together must raise ValueError.
        def both_selectors():
            return expr.applymap(
                lambda v: v + 1, columns='idx', excludes='f1')

        self.assertRaises(ValueError, both_selectors)

        def verify(result, is_mapped):
            # The result is a projection; each field is a MappedExpr when
            # `is_mapped(field)` holds and a plain Column otherwise.
            self.assertIsInstance(result, ProjectCollectionExpr)
            for field in result._fields:
                expected_type = MappedExpr if is_mapped(field) else Column
                self.assertIsInstance(field, expected_type)

        # No selector: every field is mapped.
        verify(expr.applymap(lambda v: v + 1), lambda field: True)

        # Single column name.
        verify(expr.applymap(lambda v: v + 1, columns='f1'),
               lambda field: field.name == 'f1')

        # Set of column names.
        map_cols = set(['f1', 'f2', 'f3'])
        verify(expr.applymap(lambda v: v + 1, columns=map_cols),
               lambda field: field.name in map_cols)

        # Single excluded name.
        verify(expr.applymap(lambda v: v + 1, excludes='idx'),
               lambda field: field.name != 'idx')

        # Set of excluded names.
        exc_cols = set(['idx', 'f1'])
        verify(expr.applymap(lambda v: v + 1, excludes=exc_cols),
               lambda field: field.name not in exc_cols)

    def testCallableColumn(self):
        from odps.df.expr.expressions import CallableColumn
        from odps.df.expr.collections import ProjectCollectionExpr

        # The schema includes a column named `append_id`, the same name that
        # is invoked as a method at the end of this test.
        col_names = ['name', 'f1', 'append_id']
        col_types = [types.string, types.float64, types.int64]
        schema = Schema.from_lists(col_names, col_types)
        expr = CollectionExpr(_source_data=None, _schema=schema)

        # Only the `append_id` attribute is a CallableColumn.
        self.assertIsInstance(expr.append_id, CallableColumn)
        self.assertNotIsInstance(expr.f1, CallableColumn)

        # It can be used as a column in a projection.
        with_append_id = expr[expr.name, expr.append_id]
        self.assertIsInstance(with_append_id, ProjectCollectionExpr)
        self.assertListEqual(with_append_id.schema.names,
                             ['name', 'append_id'])

        # When the column is projected away, `append_id` is no longer a
        # CallableColumn.
        without_append_id = expr[expr.name, expr.f1]
        self.assertNotIsInstance(without_append_id.append_id, CallableColumn)

        # Calling it still works and adds the requested id column.
        appended = expr.append_id(id_col='id_col')
        self.assertIn('id_col', appended.schema)

    def testAppendId(self):
        from odps.ml.expr import AlgoCollectionExpr

        # append_id() on the source-less collection: the expected result type
        # depends on whether pandas can be imported.
        appended = self.expr.append_id(id_col='id_col')
        try:
            import pandas
        except ImportError:
            # No pandas: go for XFlow
            expected_type = AlgoCollectionExpr
        else:
            # Otherwise: go for Pandas
            expected_type = AppendIDCollectionExpr
        self.assertIsInstance(appended, expected_type)
        self.assertIn('id_col', appended.schema)

        # On the sourced collection the result is checked against
        # AlgoCollectionExpr.
        sourced_result = self.sourced_expr.append_id()
        self.assertIsInstance(sourced_result, AlgoCollectionExpr)

    def testSplit(self):
        from odps.ml.expr import AlgoCollectionExpr

        # split(0.6) is unpacked into two parts; their expected type depends
        # on whether pandas can be imported.
        first, second = self.expr.split(0.6)
        try:
            import pandas
        except ImportError:
            # No pandas: go for XFlow
            pandas_available = False
        else:
            # Otherwise: go for Pandas
            pandas_available = True

        if pandas_available:
            part_type = SplitCollectionExpr
        else:
            part_type = AlgoCollectionExpr
        self.assertIsInstance(first, part_type)
        self.assertIsInstance(second, part_type)
        if pandas_available:
            # The parts must carry split ids 0 and 1, in that order.
            self.assertTupleEqual((first._split_id, second._split_id), (0, 1))
Ejemplo n.º 7
0
class Test(TestBase):
    def setup(self):
        """Build two collections over a schema derived from ``types._data_types``.

        The keys of ``types._data_types`` are passed as column names and its
        values as column types. ``self.expr`` has no source data, while
        ``self.sourced_expr`` is backed by a ``MockTable`` created with
        ``self.odps.rest`` as its client.
        """
        from odps.df.expr.tests.core import MockTable

        registry = types._data_types
        all_types_schema = Schema.from_lists(registry.keys(), registry.values())

        self.expr = CollectionExpr(_source_data=None, _schema=all_types_schema)

        mock_table = MockTable(client=self.odps.rest)
        self.sourced_expr = CollectionExpr(_source_data=mock_table,
                                           _schema=all_types_schema)

    def testSort(self):
        # Every sort variant must keep the collection type and schema and
        # record one ascending flag per sort key.
        base = self.expr

        def check(result, expected_ascending):
            self.assertIsInstance(result, CollectionExpr)
            self.assertEqual(result._schema, base._schema)
            self.assertSequenceEqual(result._ascending, expected_ascending)

        # Single column, default ordering.
        check(base.sort(base.int64), [True])

        # sort_values() with a mix of column expression and column name.
        check(base.sort_values([base.float32, 'string']), [True] * 2)

        # A scalar ``ascending`` applies to every key.
        check(base.sort([base.decimal, 'boolean', 'string'], ascending=False),
              [False] * 3)

        # A list of flags is applied key by key.
        check(base.sort([base.int8, 'datetime', 'float64'],
                        ascending=[False, True, False]),
              [False, True, False])

        # A negated column is expected to be recorded as non-ascending.
        check(base.sort([-base.int8, 'datetime', 'float64']),
              [False, True, True])

    def testDistinct(self):
        # distinct() keeps the collection type; its schema must equal the
        # schema of a projection onto the columns it was given.
        base = self.expr

        # No arguments: the full schema is kept.
        deduped = base.distinct()
        self.assertIsInstance(deduped, CollectionExpr)
        self.assertEqual(deduped._schema, base._schema)

        # A single column expression.
        deduped = base.distinct(base.string)
        self.assertIsInstance(deduped, CollectionExpr)
        by_string = base[[base.string]]
        self.assertEqual(deduped._schema, by_string._schema)

        # A list mixing a column expression and a column name.
        deduped = base.distinct([base.boolean, 'decimal'])
        self.assertIsInstance(deduped, CollectionExpr)
        by_boolean_decimal = base[['boolean', 'decimal']]
        self.assertEqual(deduped._schema, by_boolean_decimal._schema)

        # Projecting a plain column together with unique() must be rejected.
        with self.assertRaises(ExpressionError):
            self.expr['boolean', self.expr.string.unique()]

    def testMapReduce(self):
        # map_reduce() must expose the reducer's declared output columns and
        # one more sort field than there are sort columns. The first sort
        # field is always expected to be ascending (presumably the group
        # column -- not verified here).
        @output(['id', 'name', 'rating'], ['int', 'string', 'int'])
        def mapper(row):
            yield row.int64, row.string, row.int32

        @output(['name', 'rating'], ['string', 'int'])
        def reducer(_):
            i = [0]

            def h(row):
                if i[0] <= 1:
                    yield row.name, row.rating

            return h

        def check(sort, ascending, expected_flags):
            result = self.expr.map_reduce(mapper, reducer, group='name',
                                          sort=sort, ascending=ascending)
            self.assertEqual(result.schema.names, ['name', 'rating'])

            sort_fields = result._sort_fields
            self.assertEqual(len(sort_fields), len(expected_flags))
            for field, should_ascend in zip(sort_fields, expected_flags):
                if should_ascend:
                    self.assertTrue(field._ascending)
                else:
                    self.assertFalse(field._ascending)

        # Single sort column with a scalar flag.
        check('rating', False, [True, False])
        # Several sort columns with one flag per column.
        check(['rating', 'id'], [False, True], [True, False, True])
        # Several sort columns sharing one scalar flag.
        check(['rating', 'id'], False, [True, False, False])

    def testSample(self):
        # sample() returns a type that depends on its arguments and on pandas
        # availability; inconsistent argument combinations must raise
        # ExpressionError.
        for args, kwargs in (((100,), {}), ((), {'parts': 10})):
            self.assertIsInstance(self.expr.sample(*args, **kwargs),
                                  SampledCollectionExpr)

        try:
            import pandas  # noqa: F401 -- imported only to probe availability
        except ImportError:
            # No pandas: go for XFlow
            frac_result_cls = DataFrame
        else:
            # Otherwise: go for Pandas
            frac_result_cls = SampledCollectionExpr
        self.assertIsInstance(self.expr.sample(frac=0.5), frac_result_cls)
        self.assertIsInstance(self.sourced_expr.sample(frac=0.5), DataFrame)

        # Keyword sets that must be rejected, checked in this order.
        invalid_keyword_sets = (
            {},
            {'i': -1},
            {'n': 100, 'frac': 0.5},
            {'n': 100, 'parts': 10},
            {'frac': 0.5, 'parts': 10},
            {'n': 100, 'frac': 0.5, 'parts': 10},
            {'frac': -1},
            {'frac': 1.5},
            {'parts': 10, 'i': -1},
            {'parts': 10, 'i': 10},
            {'parts': 10, 'n': 10},
            {'weights': 'weights', 'strata': 'strata'},
            {'frac': 'Yes:10', 'strata': 'strata'},
            {'frac': set(), 'strata': 'strata'},
            {'n': set(), 'strata': 'strata'},
        )
        for bad_kwargs in invalid_keyword_sets:
            self.assertRaises(ExpressionError, self.expr.sample, **bad_kwargs)

    def testPivot(self):
        # pivot() must keep only the row columns as fixed schema entries while
        # still resolving unknown column names as DynamicMixin expressions.
        from odps.df.expr.dynamic import DynamicMixin

        def check(pivoted, fixed_columns):
            for column in fixed_columns:
                self.assertIn(column, pivoted._schema._name_indexes)
            self.assertEqual(len(pivoted._schema._name_indexes),
                             len(fixed_columns))
            self.assertIn('non_exist', pivoted._schema)
            self.assertIsInstance(pivoted['non_exist'], DynamicMixin)

        check(self.expr.pivot('string', 'int8', 'float32'), ['string'])
        check(self.expr.pivot(['string', 'int8'], 'int16',
                              ['datetime', 'string']),
              ['string', 'int8'])

        # A list passed as the second argument must be rejected.
        with self.assertRaises(ValueError):
            self.expr.pivot(['string', 'int8'], ['datetime', 'string'],
                            'int16')

    def testPivotTable(self):
        # pivot_table() without ``columns`` must produce a static schema named
        # ``<value>_<aggregation>``; with ``columns`` the result is dynamic.
        from odps.df.expr.dynamic import DynamicMixin

        def check_schema(result, expected_names, expected_types=None):
            self.assertEqual(result.schema.names, expected_names)
            if expected_types is not None:
                self.assertEqual(result.schema.types, expected_types)

        # Default aggregation: the expected column suffix is ``_mean``.
        pivoted = self.expr.pivot_table(values='int8', rows='float32')
        self.assertNotIsInstance(pivoted, DynamicMixin)
        check_schema(pivoted, ['float32', 'int8_mean'])

        # Several values and several row columns.
        pivoted = self.expr.pivot_table(values=('int16', 'int32'),
                                        rows=['float32', 'int8'])
        check_schema(pivoted,
                     ['float32', 'int8', 'int16_mean', 'int32_mean'])

        # Several named aggregations: columns are grouped per aggregation.
        pivoted = self.expr.pivot_table(values=('int16', 'int32'),
                                        rows=['string', 'boolean'],
                                        aggfunc=['mean', 'sum'])
        check_schema(
            pivoted,
            ['string', 'boolean', 'int16_mean', 'int32_mean', 'int16_sum',
             'int32_sum'],
            [types.string, types.boolean, types.float64, types.float64,
             types.int16, types.int32])

        # User-defined aggregation keeping a [running sum, count] buffer.
        @output(['my_mean'], ['float'])
        class Aggregator(object):
            def buffer(self):
                return [0.0, 0]

            def __call__(self, buffer, val):
                buffer[0] += val
                buffer[1] += 1

            def merge(self, buffer, pbuffer):
                buffer[0] += pbuffer[0]
                buffer[1] += pbuffer[1]

            def getvalue(self, buffer):
                # Avoid dividing by zero when nothing was accumulated.
                if buffer[1] == 0:
                    return 0.0
                return buffer[0] / buffer[1]

        # A custom aggregator contributes its declared output name and type.
        pivoted = self.expr.pivot_table(values='int16',
                                        rows='string',
                                        aggfunc=Aggregator)
        check_schema(pivoted, ['string', 'int16_my_mean'],
                     [types.string, types.float64])

        # With a mapping, the keys are expected to name the output columns.
        named_aggregators = OrderedDict([('my_agg', Aggregator),
                                         ('my_agg2', Aggregator)])
        pivoted = self.expr.pivot_table(values='int16',
                                        rows='string',
                                        aggfunc=named_aggregators)
        check_schema(pivoted, ['string', 'int16_my_agg', 'int16_my_agg2'],
                     [types.string, types.float64, types.float64])

        # Supplying ``columns`` makes the result a dynamic expression.
        pivoted = self.expr.pivot_table(values='int16',
                                        columns='boolean',
                                        rows='string')
        self.assertIsInstance(pivoted, DynamicMixin)
class Test(TestBase):
    def setup(self):
        """Create the two mock-table-backed collections used by the tests.

        ``self.expr`` has columns ``name``/``id`` and ``self.expr2`` has
        ``name2``/``id2``; both pairs are declared as ``string``/``int64``.
        """
        def resolve(*type_names):
            return [validate_data_type(type_name) for type_name in type_names]

        first_schema = Schema.from_lists(["name", "id"],
                                         resolve("string", "int64"))
        first_table = MockTable(name="pyodps_test_expr_table",
                                schema=first_schema)
        self.expr = CollectionExpr(_source_data=first_table,
                                   _schema=first_schema)

        second_schema = Schema.from_lists(["name2", "id2"],
                                          resolve("string", "int64"))
        second_table = MockTable(name="pyodps_test_expr_table2",
                                 schema=second_schema)
        self.expr2 = CollectionExpr(_source_data=second_table,
                                    _schema=second_schema)

    def _lines_eq(self, expected, actual):
        """Assert that two multi-line texts match line by line.

        Both texts are split on newlines; every line is right-stripped and
        passed through ``to_str`` before the two sequences are compared, so
        trailing whitespace never causes a mismatch.
        """
        def normalized(text):
            return [to_str(raw_line.rstrip()) for raw_line in text.split("\n")]

        expected_lines = normalized(expected)
        actual_lines = normalized(actual)
        self.assertSequenceEqual(expected_lines, actual_lines)

    def testProjectionFormatter(self):
        # repr() of a cast applied to a renamed column taken from a projection.
        renamed_id = self.expr.id.rename("new_id")
        projected = self.expr["name", renamed_id]
        casted = projected.new_id.astype("float32")
        self._lines_eq(EXPECTED_PROJECTION_FORMAT, repr(casted))

    def testFilterFormatter(self):
        # repr() of a collection filtered by two AND-ed column predicates.
        name_is_not_test = self.expr.name != "test"
        id_above_100 = self.expr.id > 100
        filtered = self.expr[name_is_not_test & id_above_100]
        self._lines_eq(EXPECTED_FILTER_FORMAT, repr(filtered))

    def testSliceFormatter(self):
        # repr() of a slice with only a stop value.
        head_slice = self.expr[:100]
        self._lines_eq(EXPECTED_SLICE_FORMAT, repr(head_slice))

        # repr() of a slice with explicit start, stop and step.
        stepped_slice = self.expr[5:100:3]
        self._lines_eq(EXPECTED_SLICE_WITH_START_STEP_FORMAT,
                       repr(stepped_slice))

    def testArithmeticFormatter(self):
        # Checks repr() of a mixed arithmetic expression. When the exact
        # line-by-line comparison fails, a differing pair of lines is accepted
        # only if both parse as floats that are almost equal; otherwise the
        # original assertion error is raised.
        expr = self.expr
        d = -(expr["id"]) + 20.34 - expr["id"] + float(20) * expr["id"] - expr["id"] / 4.9 + 40 // 2 + expr["id"] // 1.2

        try:
            self._lines_eq(EXPECTED_ARITHMETIC_FORMAT, repr(d))
        except AssertionError as e:
            left = [to_str(line.rstrip()) for line in EXPECTED_ARITHMETIC_FORMAT.split("\n")]
            right = [to_str(line.rstrip()) for line in repr(d).split("\n")]
            self.assertEqual(len(left), len(right))
            for expected_line, actual_line in zip(left, right):
                if expected_line == actual_line:
                    continue
                try:
                    self.assertAlmostEqual(float(expected_line), float(actual_line))
                except (TypeError, ValueError, AssertionError):
                    # Only a failed float parse or a failed comparison falls
                    # back to the original error; KeyboardInterrupt and
                    # SystemExit are no longer swallowed by a bare except.
                    raise e

    def testSortFormatter(self):
        # repr() of a sort by a column name and a negated column expression.
        negated_id = -self.expr.id
        ordered = self.expr.sort(["name", negated_id])
        self._lines_eq(EXPECTED_SORT_FORMAT, repr(ordered))

    def testDistinctFormatter(self):
        # repr() of distinct() over a column name and a computed column.
        id_plus_one = self.expr.id + 1
        deduped = self.expr.distinct(["name", id_plus_one])
        self._lines_eq(EXPECTED_DISTINCT_FORMAT, repr(deduped))

    def testGroupbyFormatter(self):
        # repr() of a grouped aggregation with a named aggregate.
        by_name_and_id = self.expr.groupby(["name", "id"])
        id_total = self.expr.id.sum()
        aggregated = by_name_and_id.agg(new_id=id_total)
        self._lines_eq(EXPECTED_GROUPBY_FORMAT, repr(aggregated))

        # repr() of mutate() with a window function built on the same group.
        by_name = self.expr.groupby(["name"])
        row_numbers = by_name.row_number()
        mutated = by_name.mutate(row_numbers)
        self._lines_eq(EXPECTED_MUTATE_FORMAT, repr(mutated))

        # repr() of count() on a group-by.
        counted = self.expr.groupby(["name", "id"]).count()
        self._lines_eq(EXPECTED_GROUPBY_COUNT_FORMAT, repr(counted))

    def testReductionFormatter(self):
        # repr() of a reduction applied to a grouped column.
        grouped_id = self.expr.groupby(["id"]).id
        self._lines_eq(EXPECTED_REDUCTION_FORMAT, repr(grouped_id.std()))

        # repr() of a reduction applied to a plain column.
        id_mean = self.expr.id.mean()
        self._lines_eq(EXPECTED_REDUCTION_FORMAT2, repr(id_mean))

        # repr() of count() on the collection itself.
        row_count = self.expr.count()
        self._lines_eq(EXPECTED_REDUCTION_FORMAT3, repr(row_count))

    def testWindowFormatter(self):
        # repr() of rank() over a group sorted by a negated column.
        by_name = self.expr.groupby(["name"])
        negated_id = -self.expr.id
        ranked = by_name.sort(negated_id).name.rank()
        self._lines_eq(EXPECTED_WINDOW_FORMAT1, repr(ranked))

        # repr() of cummean() with explicit window bounds and unique=True.
        grouped_id = self.expr.groupby(["id"]).id
        running_mean = grouped_id.cummean(preceding=10, following=5,
                                          unique=True)
        self._lines_eq(EXPECTED_WINDOW_FORMAT2, repr(running_mean))

    def testElementFormatter(self):
        # repr() of a string method applied to a column.
        contains_test = self.expr.name.contains("test")
        self._lines_eq(EXPECTED_STRING_FORMAT, repr(contains_test))

        # repr() of between() on a column.
        id_in_range = self.expr.id.between(1, 3)
        self._lines_eq(EXPECTED_ELEMENT_FORMAT, repr(id_in_range))

        # repr() of a datetime method applied after casting the column.
        name_as_datetime = self.expr.name.astype("datetime")
        year_text = name_as_datetime.strftime("%Y")
        self._lines_eq(EXPECTED_DATETIME_FORMAT, repr(year_text))

        # repr() of switch() with two cases and a default branch.
        id_column = self.expr.id
        when_three = self.expr.name
        when_four = self.expr.name + "abc"
        fallback = self.expr.name + "test"
        switched = id_column.switch(3, when_three, 4, when_four,
                                    default=fallback)
        self._lines_eq(EXPECTED_SWITCH_FORMAT, repr(switched))

    def testJoinFormatter(self):
        # repr() of a join of the two collections on a (left, right) name pair.
        join_columns = ("name", "name2")
        joined = self.expr.join(self.expr2, join_columns)
        self._lines_eq(EXPECTED_JOIN_FORMAT, repr(joined))

    def testAstypeFormatter(self):
        expr = self.expr.id.astype("float")
        self._lines_eq(EXPECTED_CAST_FORMAT, repr(expr))