Ejemplo n.º 1
0
def test_custom_aggregation():
    """Compile each mock ``CustomReduction`` and inspect its aggregation step.

    Both mock reductions are compiled for 1-d and 2-d input.  Every
    compilation must produce exactly one agg func whose map and agg phases
    are both named ``'custom_reduction'``, whose ``custom_reduction``
    attribute is an instance of the registered mock type, and whose
    ``output_limit`` matches the expectation listed for that type.
    """
    class MockReduction1(CustomReduction):
        # only the agg stage is overridden; it yields a single value
        def agg(self, v1):
            return v1.sum()

    class MockReduction2(CustomReduction):
        # pre/agg/post are all overridden; agg yields two values
        def pre(self, value):
            return value + 1, value ** 2

        def agg(self, v1, v2):
            return v1.sum(), v2.prod()

        def post(self, v1, v2):
            return v1 + v2

    # (reduction type, expected output_limit).  The expected limit equals
    # the number of values each mock's ``agg`` returns -- presumably that is
    # what output_limit tracks; confirm against ReductionCompiler.
    expectations = ((MockReduction1, 1), (MockReduction2, 2))

    for ndim in (1, 2):
        for reduction_type, expected_limit in expectations:
            compiler = ReductionCompiler()
            compiler.add_function(reduction_type(), ndim=ndim)
            agg_funcs = compiler.compile().agg_funcs

            # check agg_funcs
            assert len(agg_funcs) == 1
            only = agg_funcs[0]
            assert only.map_func_name == 'custom_reduction'
            assert only.agg_func_name == 'custom_reduction'
            assert isinstance(only.custom_reduction, reduction_type)
            assert only.output_limit == expected_limit
Ejemplo n.º 2
0
    def testCustomAggregation(self):
        """Check the agg funcs compiled from mock ``CustomReduction`` types.

        Each mock is compiled for 1-d and 2-d input.  The compiled result
        must hold exactly one agg func whose map and agg phases are named
        ``'custom_reduction'``, whose ``custom_reduction`` attribute is an
        instance of the registered mock type, and whose ``output_limit``
        matches the expectation listed for that type.
        """
        class MockReduction1(CustomReduction):
            # only the agg stage is overridden; it yields a single value
            def agg(self, v1):
                return v1.sum()

        class MockReduction2(CustomReduction):
            # pre/agg/post are all overridden; agg yields two values
            def pre(self, value):
                return value + 1, value ** 2

            def agg(self, v1, v2):
                return v1.sum(), v2.prod()

            def post(self, v1, v2):
                return v1 + v2

        # (reduction type, expected output_limit); the limit mirrors how
        # many values each mock's ``agg`` returns -- presumably what
        # output_limit tracks; confirm against ReductionCompiler.
        expectations = [(MockReduction1, 1), (MockReduction2, 2)]

        for ndim in (1, 2):
            for reduction_type, expected_limit in expectations:
                compiler = ReductionCompiler()
                compiler.add_function(reduction_type(), ndim=ndim)
                agg_funcs = compiler.compile().agg_funcs

                # check agg_funcs
                self.assertEqual(len(agg_funcs), 1)
                head = agg_funcs[0]
                self.assertEqual(head.map_func_name, 'custom_reduction')
                self.assertEqual(head.agg_func_name, 'custom_reduction')
                self.assertIsInstance(head.custom_reduction, reduction_type)
                self.assertEqual(head.output_limit, expected_limit)
Ejemplo n.º 3
0
def test_compile_function():
    """Exercise ``ReductionCompiler`` on rejected and accepted functions.

    First, ``add_function`` must raise ``ValueError`` for every entry in the
    rejection table.  Then valid reductions are compiled over all data (for
    1-d and 2-d input) and over explicit column subsets, and the generated
    pre/agg/post stages are inspected.
    """
    compiler = ReductionCompiler()
    ms = md.Series([1, 2, 3])
    # (function, ndim) pairs that add_function must reject; all of them are
    # fed to the same compiler, in this order.
    rejected = [
        # no Mars objects inside closures
        (functools.partial(lambda x: (x + ms).sum()), 2),
        # function should return a Mars object
        (lambda x: x is not None, 2),
        # function should reduce dimensionality in some way
        (lambda x: x, 2),
        # function should only contain acceptable operands
        (lambda x: x.sort_values().max(), 1),
        (lambda x: x.max().shift(1), 2),
    ]
    for bad_func, bad_ndim in rejected:
        with pytest.raises(ValueError):
            compiler.add_function(bad_func, ndim=bad_ndim)

    # test agg for all data
    for ndim in (1, 2):
        compiler = ReductionCompiler(store_source=True)
        compiler.add_function(lambda x: (x**2).count() + 1, ndim=ndim)
        res = compiler.compile()
        # check pre_funcs
        assert len(res.pre_funcs) == 1
        assert 'pow' in res.pre_funcs[0].func.__source__
        # check agg_funcs
        assert len(res.agg_funcs) == 1
        assert res.agg_funcs[0].map_func_name == 'count'
        assert res.agg_funcs[0].agg_func_name == 'sum'
        # check post_funcs
        assert len(res.post_funcs) == 1
        assert res.post_funcs[0].func_name == '<lambda_0>'
        assert 'add' in res.post_funcs[0].func.__source__

        # register a second function on the same compiler
        compiler.add_function(lambda x: -x.prod()**2 + (1 + (x**2).count()),
                              ndim=ndim)
        res = compiler.compile()
        # check pre_funcs: exactly one of the two pre steps contains 'pow'
        assert len(res.pre_funcs) == 2
        assert sum('pow' in pre.func.__source__ for pre in res.pre_funcs) == 1
        # check agg_funcs
        assert len(res.agg_funcs) == 2
        assert {agg.map_func_name for agg in res.agg_funcs} == {'count', 'prod'}
        assert {agg.agg_func_name for agg in res.agg_funcs} == {'sum', 'prod'}
        # check post_funcs
        assert len(res.post_funcs) == 2
        assert res.post_funcs[0].func_name == '<lambda_0>'
        assert all('add' in post.func.__source__ for post in res.post_funcs)

        compiler = ReductionCompiler(store_source=True)
        compiler.add_function(lambda x: where_function(x.all(), x.count(), 0),
                              ndim=ndim)
        res = compiler.compile()
        # check pre_funcs
        assert len(res.pre_funcs) == 1
        only_pre = res.pre_funcs[0]
        assert only_pre.input_key == only_pre.output_key
        # check agg_funcs
        assert len(res.agg_funcs) == 2
        assert {agg.map_func_name for agg in res.agg_funcs} == {'all', 'count'}
        assert {agg.agg_func_name for agg in res.agg_funcs} == {'sum', 'all'}
        # check post_funcs: 1-d output uses np.where, 2-d uses a .where call
        assert len(res.post_funcs) == 1
        post_source = res.post_funcs[0].func.__source__
        if ndim == 1:
            assert 'np.where' in post_source
        else:
            assert 'np.where' not in post_source
            assert '.where' in post_source

    # test agg for specific columns
    compiler = ReductionCompiler(store_source=True)
    compiler.add_function(lambda x: 1 + x.sum(), ndim=2, cols=['a', 'b'])
    compiler.add_function(lambda x: -1 + x.sum(), ndim=2, cols=['b', 'c'])
    res = compiler.compile()
    # check pre_funcs
    assert len(res.pre_funcs) == 1
    assert set(res.pre_funcs[0].columns) == set('abc')
    # check agg_funcs
    assert len(res.agg_funcs) == 1
    assert res.agg_funcs[0].map_func_name == 'sum'
    assert res.agg_funcs[0].agg_func_name == 'sum'
    # check post_funcs
    assert len(res.post_funcs) == 2
    post_columns = {''.join(sorted(post.columns)) for post in res.post_funcs}
    assert post_columns == {'ab', 'bc'}
Ejemplo n.º 4
0
    def testCompileFunction(self):
        """Exercise ``ReductionCompiler`` on rejected and accepted functions.

        ``add_function`` must raise ``ValueError`` for every entry in the
        rejection table.  Valid reductions are then compiled over all data
        (1-d and 2-d input) and over explicit column subsets, and the
        generated pre/agg/post stages are inspected.
        """
        compiler = ReductionCompiler()
        ms = md.Series([1, 2, 3])
        # (function, ndim) pairs that add_function must reject; all are fed
        # to the same compiler, in this order.
        rejected = [
            # no Mars objects inside closures
            (functools.partial(lambda x: (x + ms).sum()), 2),
            # function should return a Mars object
            (lambda x: x is not None, 2),
            # function should reduce dimensionality in some way
            (lambda x: x, 2),
            # function should only contain acceptable operands
            (lambda x: x.sort_values().max(), 1),
            (lambda x: x.max().shift(1), 2),
        ]
        for bad_func, bad_ndim in rejected:
            with self.assertRaises(ValueError):
                compiler.add_function(bad_func, ndim=bad_ndim)

        # test agg for all data
        for ndim in (1, 2):
            compiler = ReductionCompiler(store_source=True)
            compiler.add_function(lambda x: (x**2).count() + 1, ndim=ndim)
            res = compiler.compile()
            # check pre_funcs
            self.assertEqual(len(res.pre_funcs), 1)
            self.assertIn('pow', res.pre_funcs[0].func.__source__)
            # check agg_funcs
            self.assertEqual(len(res.agg_funcs), 1)
            self.assertEqual(res.agg_funcs[0].map_func_name, 'count')
            self.assertEqual(res.agg_funcs[0].agg_func_name, 'sum')
            # check post_funcs
            self.assertEqual(len(res.post_funcs), 1)
            self.assertEqual(res.post_funcs[0].func_name, '<lambda_0>')
            self.assertIn('add', res.post_funcs[0].func.__source__)

            # register a second function on the same compiler
            compiler.add_function(
                lambda x: -x.prod()**2 + (1 + (x**2).count()), ndim=ndim)
            res = compiler.compile()
            # check pre_funcs: exactly one of the two contains 'pow'
            self.assertEqual(len(res.pre_funcs), 2)
            pow_hits = sum('pow' in pre.func.__source__
                           for pre in res.pre_funcs)
            self.assertEqual(pow_hits, 1)
            # check agg_funcs
            self.assertEqual(len(res.agg_funcs), 2)
            self.assertSetEqual(
                {agg.map_func_name for agg in res.agg_funcs},
                {'count', 'prod'})
            self.assertSetEqual(
                {agg.agg_func_name for agg in res.agg_funcs},
                {'sum', 'prod'})
            # check post_funcs
            self.assertEqual(len(res.post_funcs), 2)
            self.assertEqual(res.post_funcs[0].func_name, '<lambda_0>')
            for post in res.post_funcs:
                self.assertIn('add', post.func.__source__)

            compiler = ReductionCompiler(store_source=True)
            compiler.add_function(
                lambda x: where_function(x.all(), x.count(), 0), ndim=ndim)
            res = compiler.compile()
            # check pre_funcs
            self.assertEqual(len(res.pre_funcs), 1)
            only_pre = res.pre_funcs[0]
            self.assertEqual(only_pre.input_key, only_pre.output_key)
            # check agg_funcs
            self.assertEqual(len(res.agg_funcs), 2)
            self.assertSetEqual(
                {agg.map_func_name for agg in res.agg_funcs},
                {'all', 'count'})
            self.assertSetEqual(
                {agg.agg_func_name for agg in res.agg_funcs},
                {'sum', 'all'})
            # check post_funcs: 1-d output uses np.where, 2-d a .where call
            self.assertEqual(len(res.post_funcs), 1)
            post_source = res.post_funcs[0].func.__source__
            if ndim == 1:
                self.assertIn('np.where', post_source)
            else:
                self.assertNotIn('np.where', post_source)
                self.assertIn('.where', post_source)

        # test agg for specific columns
        compiler = ReductionCompiler(store_source=True)
        compiler.add_function(lambda x: 1 + x.sum(), ndim=2, cols=['a', 'b'])
        compiler.add_function(lambda x: -1 + x.sum(), ndim=2, cols=['b', 'c'])
        res = compiler.compile()
        # check pre_funcs
        self.assertEqual(len(res.pre_funcs), 1)
        self.assertSetEqual(set(res.pre_funcs[0].columns), set('abc'))
        # check agg_funcs
        self.assertEqual(len(res.agg_funcs), 1)
        self.assertEqual(res.agg_funcs[0].map_func_name, 'sum')
        self.assertEqual(res.agg_funcs[0].agg_func_name, 'sum')
        # check post_funcs
        self.assertEqual(len(res.post_funcs), 2)
        self.assertSetEqual(
            {''.join(sorted(post.columns)) for post in res.post_funcs},
            {'ab', 'bc'})