def test_count_per_element(self):
    spark_operations = SparkRDDOperations()
    data = ['a', 'b', 'a']
    dist_data = SparkRDDOperationsTest.sc.parallelize(data)
    rdd = spark_operations.count_per_element(dist_data)
    result = dict(rdd.collect())
    self.assertDictEqual(result, {'a': 2, 'b': 1})
def test_filter_by_key_none_public_partitions(self):
    spark_operations = SparkRDDOperations()
    data = [(1, 11, 111), (2, 22, 222)]
    dist_data = SparkRDDOperationsTest.sc.parallelize(data)
    public_partitions = None
    with self.assertRaises(TypeError):
        spark_operations.filter_by_key(
            dist_data, public_partitions,
            SparkRDDOperationsTest.data_extractors)
def test_sample_fixed_per_key(self):
    spark_operations = SparkRDDOperations()
    data = [(1, 11), (2, 22), (3, 33), (1, 14), (2, 25), (1, 16)]
    dist_data = SparkRDDOperationsTest.sc.parallelize(data)
    rdd = spark_operations.sample_fixed_per_key(dist_data, 2)
    result = dict(rdd.collect())
    # Key 1 has three values, so exactly 2 are kept; which 2 is random,
    # hence the subset check rather than an exact-set check.
    self.assertEqual(len(result[1]), 2)
    self.assertTrue(set(result[1]).issubset({11, 14, 16}))
    self.assertSetEqual(set(result[2]), {22, 25})
    self.assertSetEqual(set(result[3]), {33})
# The `distributed` argument implies this test is parameterized (e.g. with
# absl's @parameterized.parameters); the decorator is assumed here, as it
# is missing from this excerpt.
@parameterized.parameters(False, True)
def test_filter_by_key_nonempty_public_partitions(self, distributed):
    spark_operations = SparkRDDOperations()
    data = [(1, 11, 111), (2, 22, 222)]
    dist_data = SparkRDDOperationsTest.sc.parallelize(data)
    public_partitions = [11, 33]
    if distributed:
        public_partitions = SparkRDDOperationsTest.sc.parallelize(
            public_partitions)
    result = spark_operations.filter_by_key(
        dist_data, public_partitions,
        SparkRDDOperationsTest.data_extractors).collect()
    self.assertListEqual(result, [(11, (1, 11, 111))])
def test_reduce_accumulators_per_key(self):
    spark_operations = SparkRDDOperations()
    data = [(1, 11), (2, 22), (3, 33), (1, 14), (2, 25), (1, 16)]
    dist_data = SparkRDDOperationsTest.sc.parallelize(data)
    rdd = spark_operations.map_values(dist_data, SumAccumulator,
                                      "Wrap into accumulators")
    result = dict(
        spark_operations.reduce_accumulators_per_key(
            rdd, "Reduce accumulator per key").map(
                lambda row: (row[0], row[1].get_metrics())).collect())
    self.assertDictEqual(result, {1: 41, 2: 47, 3: 33})
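# SumAccumulator is defined elsewhere in this test module; for readers of
# this excerpt, it is assumed to keep a running sum and expose it via
# get_metrics(), roughly along these lines (a sketch, not the actual class):
#
#   class SumAccumulator:
#       def __init__(self, value):
#           self._sum = value
#       def add_accumulator(self, other):
#           self._sum += other._sum
#           return self
#       def get_metrics(self):
#           return self._sum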
@classmethod
def setUpClass(cls):
    conf = pyspark.SparkConf()
    cls.sc = pyspark.SparkContext(conf=conf)
    cls.data_extractors = DataExtractors(
        partition_extractor=lambda x: x[1],
        privacy_id_extractor=lambda x: x[0],
        value_extractor=lambda x: x[2])
    cls.ops = SparkRDDOperations()
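# Not in the original excerpt: a minimal cleanup sketch. A SparkContext
# created in setUpClass should be stopped once the suite finishes, so a
# teardown along these lines is assumed to exist.
@classmethod
def tearDownClass(cls):
    cls.sc.stop()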
def test_flat_map(self):
    spark_operations = SparkRDDOperations()
    data = [[1, 2, 3, 4], [5, 6, 7, 8]]
    dist_data = SparkRDDOperationsTest.sc.parallelize(data)
    self.assertEqual(
        spark_operations.flat_map(dist_data, lambda x: x).collect(),
        [1, 2, 3, 4, 5, 6, 7, 8])

    data = [("a", [1, 2, 3, 4]), ("b", [5, 6, 7, 8])]
    dist_data = SparkRDDOperationsTest.sc.parallelize(data)
    self.assertEqual(
        spark_operations.flat_map(dist_data, lambda x: x[1]).collect(),
        [1, 2, 3, 4, 5, 6, 7, 8])
    self.assertEqual(
        spark_operations.flat_map(
            dist_data, lambda x: [(x[0], y) for y in x[1]]).collect(),
        [("a", 1), ("a", 2), ("a", 3), ("a", 4), ("b", 5), ("b", 6),
         ("b", 7), ("b", 8)])