예제 #1
0
class TestSeriesRegionMeanMethods(PySparkTestCase):
    """Tests for ``Series.meanOfRegion`` and ``Series.meanByRegions``.

    The fixture is four records keyed by the 2-d coordinates (0, 0), (0, 1),
    (1, 0) and (1, 1), each holding a three-element float array.

    Test methods are described with ``#`` comments rather than docstrings so
    that test runners keep reporting them by method name.
    """

    def setUp(self):
        """Build the local records and wrap them in a ``Series``."""
        super(TestSeriesRegionMeanMethods, self).setUp()
        self.dataLocal = [
            ((0, 0), array([1.0, 2.0, 3.0])),
            ((0, 1), array([2.0, 2.0, 4.0])),
            ((1, 0), array([4.0, 2.0, 1.0])),
            ((1, 1), array([3.0, 1.0, 1.0]))
        ]
        self.series = Series(self.sc.parallelize(self.dataLocal),
                             dtype=self.dataLocal[0][1].dtype,
                             index=arange(3))

    def __setup_meanByRegion(self, useMask=False):
        """Return ``(keys, expectedKeys, expected)`` for a two-record region.

        The region consists of the records at positions 1 and 2 of
        ``self.dataLocal``, i.e. keys (0, 1) and (1, 0).

        :param useMask: when true, ``keys`` is returned as a 2x2 ``uint8``
            array with ones at [0, 1] and [1, 0] instead of a list of key
            tuples.
        :return: ``keys`` (list of key tuples, or the mask array),
            ``expectedKeys`` (element-wise mean of the key tuples cast to
            int16, as a tuple) and ``expected`` (element-wise mean of the
            two records' value arrays).
        """
        itemIdxs = [1, 2]  # data keys for items 1 and 2 (0-based)
        keys = [self.dataLocal[idx][0] for idx in itemIdxs]

        # The expectations are always derived from the explicit key list,
        # even when a mask is handed back to the caller below.
        expectedKeys = tuple(vstack(keys).mean(axis=0).astype('int16'))
        expected = vstack([self.dataLocal[idx][1] for idx in itemIdxs]).mean(axis=0)
        if useMask:
            keys = array([[0, 1], [1, 0]], dtype='uint8')
        return keys, expectedKeys, expected

    @staticmethod
    def __checkAsserts(expectedLen, expectedKeys, expected, actual):
        """Check a single ``(key, value)`` result.

        Asserts that ``actual`` has ``expectedLen`` elements, that
        ``actual[0]`` equals ``expectedKeys`` and that ``actual[1]`` is
        array-equal to ``expected``.
        """
        assert_equals(expectedLen, len(actual))
        assert_equals(expectedKeys, actual[0])
        assert_true(array_equal(expected, actual[1]))

    @staticmethod
    def __checkNestedAsserts(expectedLen, expectedKeys, expected, actual):
        """Check a sequence of ``(key, value)`` results position by position.

        Asserts that ``actual`` has ``expectedLen`` elements and that, for
        each position ``i``, ``actual[i][0]`` equals ``expectedKeys[i]`` and
        ``actual[i][1]`` is array-equal to ``expected[i]``.
        """
        assert_equals(expectedLen, len(actual))
        # ``range`` rather than ``xrange`` so the helper runs on Python 2 and 3.
        for i in range(expectedLen):
            assert_equals(expectedKeys[i], actual[i][0])
            assert_true(array_equal(expected[i], actual[i][1]))

    def __checkReturnedSeriesAttributes(self, newSeries):
        """Check the attributes of a ``Series`` returned by ``meanByRegions``."""
        assert_true(newSeries._dims is None)  # check that new _dims is unset
        assert_equals(self.series.dtype, newSeries._dtype)  # check that new dtype is set
        assert_true(array_equal(self.series.index, newSeries._index))  # check that new index is set
        assert_is_not_none(newSeries.dims)  # check that new dims is at least calculable (expected to be meaningless)

    def __run_tst_meanOfRegion(self, useMask):
        """Run ``meanOfRegion`` on the two-record region and check the result.

        :param useMask: passed through to ``__setup_meanByRegion`` to choose
            between a key list and a mask array.
        """
        keys, expectedKeys, expected = self.__setup_meanByRegion(useMask)
        actual = self.series.meanOfRegion(keys)
        # meanOfRegion's result is checked as a 2-element (key, value) pair.
        self.__checkAsserts(2, expectedKeys, expected, actual)

    def test_meanOfRegion(self):
        # region given as an explicit list of keys
        self.__run_tst_meanOfRegion(False)

    def test_meanOfRegionWithMask(self):
        # region given as a mask array
        self.__run_tst_meanOfRegion(True)

    def test_meanOfRegionErrorsOnMissing(self):
        keys = [(17, 24), (17, 25)]
        # if no records match, return None, None
        actualKey, actualVal = self.series.meanOfRegion(keys)
        assert_is_none(actualKey)
        assert_is_none(actualVal)
        # if we have only a partial match but haven't turned on validation, return a sensible value
        keys = [(0, 1), (17, 25)]
        actualKey, actualVal = self.series.meanOfRegion(keys)
        assert_equals((0, 1), actualKey)
        assert_true(array_equal(self.dataLocal[1][1], actualVal))
        # throw an error on a partial match when validation turned on
        assert_raises(ValueError, self.series.meanOfRegion, keys, validate=True)

    def test_meanByRegions_singleRegion(self):
        # one region, passed as a single-element list of key lists
        keys, expectedKeys, expected = self.__setup_meanByRegion()

        actualSeries = self.series.meanByRegions([keys])
        actual = actualSeries.collect()
        self.__checkReturnedSeriesAttributes(actualSeries)
        self.__checkNestedAsserts(1, [expectedKeys], [expected], actual)

    def test_meanByRegionsErrorsOnMissing(self):
        keys, expectedKeys, expected = self.__setup_meanByRegion()
        # add a key that matches no record in the fixture
        keys += [(17, 25)]

        # check that we get a sensible value with validation turned off:
        actualSeries = self.series.meanByRegions([keys])
        actual = actualSeries.collect()
        self.__checkReturnedSeriesAttributes(actualSeries)
        self.__checkNestedAsserts(1, [expectedKeys], [expected], actual)

        # throw an error on a partial match when validation turned on
        # this error will be on the workers, which propagates back to the driver
        # as something other than the ValueError that it started out life as
        assert_raises(Exception, self.series.meanByRegions([keys], validate=True).count)

    def test_meanByRegions_singleRegionWithMask(self):
        # one region, passed as a mask array
        mask, expectedKeys, expected = self.__setup_meanByRegion(True)

        actualSeries = self.series.meanByRegions(mask)
        actual = actualSeries.collect()
        self.__checkReturnedSeriesAttributes(actualSeries)
        self.__checkNestedAsserts(1, [expectedKeys], [expected], actual)

    def test_meanByRegions_twoRegions(self):
        # two regions sharing the record at position 1, passed as nested key lists
        nestedKeys, expectedKeys, expected = [], [], []
        for itemIdxs in [(0, 1), (1, 2)]:
            keys = [self.dataLocal[idx][0] for idx in itemIdxs]
            nestedKeys.append(keys)
            avgKeys = tuple(vstack(keys).mean(axis=0).astype('int16'))
            expectedKeys.append(avgKeys)
            avgVals = vstack([self.dataLocal[idx][1] for idx in itemIdxs]).mean(axis=0)
            expected.append(avgVals)

        actualSeries = self.series.meanByRegions(nestedKeys)
        actual = actualSeries.collect()
        self.__checkReturnedSeriesAttributes(actualSeries)
        self.__checkNestedAsserts(2, expectedKeys, expected, actual)

    def test_meanByRegions_twoRegionsWithMask(self):
        expectedKeys, expected = [], []
        # mask value 1 sits at keys (0, 0) and (0, 1), value 2 at key (1, 0);
        # the expectations below are built from the same groupings
        mask = array([[1, 1], [2, 0]], dtype='uint8')
        for itemIdxs in [(0, 1), (2, )]:
            keys = [self.dataLocal[idx][0] for idx in itemIdxs]
            avgKeys = tuple(vstack(keys).mean(axis=0).astype('int16'))
            expectedKeys.append(avgKeys)
            avgVals = vstack([self.dataLocal[idx][1] for idx in itemIdxs]).mean(axis=0)
            expected.append(avgVals)

        actualSeries = self.series.meanByRegions(mask)
        actual = actualSeries.collect()
        self.__checkReturnedSeriesAttributes(actualSeries)
        self.__checkNestedAsserts(2, expectedKeys, expected, actual)