    def test_count_per_element(self):
        spark_operations = SparkRDDOperations()
        data = ['a', 'b', 'a']
        dist_data = SparkRDDOperationsTest.sc.parallelize(data)
        rdd = spark_operations.count_per_element(dist_data)
        result = rdd.collect()
        result = dict(result)
        self.assertDictEqual(result, {'a': 2, 'b': 1})
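For reference, the behaviour this test pins down is the classic per-element count. A minimal plain-PySpark sketch of the expected semantics (illustrative only, not the library's actual implementation; the helper name is hypothetical):

import operator

def count_per_element_sketch(rdd):
    # Pair every element with 1, then sum the 1s per distinct element.
    return rdd.map(lambda x: (x, 1)).reduceByKey(operator.add)

With the input ['a', 'b', 'a'] this yields [('a', 2), ('b', 1)], matching the assertion above.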
    def test_filter_by_key_none_public_partitions(self):
        spark_operations = SparkRDDOperations()
        data = [(1, 11, 111), (2, 22, 222)]
        dist_data = SparkRDDOperationsTest.sc.parallelize(data)
        public_partitions = None
        with self.assertRaises(TypeError):
            spark_operations.filter_by_key(
                dist_data, public_partitions,
                SparkRDDOperationsTest.data_extractors)
    def test_sample_fixed_per_key(self):
        spark_operations = SparkRDDOperations()
        data = [(1, 11), (2, 22), (3, 33), (1, 14), (2, 25), (1, 16)]
        dist_data = SparkRDDOperationsTest.sc.parallelize(data)
        rdd = spark_operations.sample_fixed_per_key(dist_data, 2)
        result = dict(rdd.collect())
        self.assertEqual(len(result[1]), 2)
        self.assertTrue(set(result[1]).issubset({11, 14, 16}))
        self.assertSetEqual(set(result[2]), {22, 25})
        self.assertSetEqual(set(result[3]), {33})
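The assertions above only require that at most n values survive per key and that the kept values come from that key's original values. A plain-PySpark sketch of such a sampler (illustrative only; random.sample stands in for whatever sampling the library actually performs, and the helper name is hypothetical):

import random

def sample_fixed_per_key_sketch(rdd, max_per_key):
    def sample(values):
        values = list(values)
        # Keep at most max_per_key values, chosen uniformly without replacement.
        return random.sample(values, min(max_per_key, len(values)))
    return rdd.groupByKey().mapValues(sample)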
Example #4
    def test_filter_by_key_nonempty_public_partitions(self, distributed):
        # 'distributed' is supplied by a parameterized decorator in the
        # original test module; it switches the public partitions between a
        # local list and an RDD.
        spark_operations = SparkRDDOperations()
        data = [(1, 11, 111), (2, 22, 222)]
        dist_data = SparkRDDOperationsTest.sc.parallelize(data)
        public_partitions = [11, 33]
        if distributed:
            public_partitions = SparkRDDOperationsTest.sc.parallelize(
                public_partitions)
        result = spark_operations.filter_by_key(
            dist_data, public_partitions,
            SparkRDDOperationsTest.data_extractors).collect()
        self.assertListEqual(result, [(11, (1, 11, 111))])
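The expected result shape, (partition_key, row) pairs restricted to public partitions, can be sketched in plain PySpark as follows (illustrative only; it covers just the local-list case, not the RDD branch exercised when distributed is True, and the helper name is hypothetical):

def filter_by_key_sketch(rdd, public_partitions, data_extractors):
    # Key each row by its partition key, then drop non-public partitions.
    public = set(public_partitions)
    return rdd.map(lambda row: (data_extractors.partition_extractor(row), row)) \
              .filter(lambda kv: kv[0] in public)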
    def test_reduce_accumulators_per_key(self):
        spark_operations = SparkRDDOperations()
        data = [(1, 11), (2, 22), (3, 33), (1, 14), (2, 25), (1, 16)]
        dist_data = SparkRDDOperationsTest.sc.parallelize(data)
        rdd = spark_operations.map_values(dist_data, SumAccumulator,
                                          "Wrap into accumulators")
        result = spark_operations \
            .reduce_accumulators_per_key(rdd, "Reduce accumulator per key") \
            .map(lambda row: (row[0], row[1].get_metrics())) \
            .collect()
        result = dict(result)
        self.assertDictEqual(result, {1: 41, 2: 47, 3: 33})
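To see what this pipeline computes, here is a toy stand-in for the accumulator and the per-key reduction (the library's real SumAccumulator interface may differ; only the construct-from-value and get_metrics() usage are taken from the test above, and both names below are hypothetical):

class SumAccumulatorSketch:
    """Toy accumulator: wraps one value, merges by addition."""

    def __init__(self, value):
        self._sum = value

    def add_accumulator(self, other):
        self._sum += other._sum
        return self

    def get_metrics(self):
        return self._sum


def reduce_accumulators_per_key_sketch(rdd):
    # rdd holds (key, accumulator) pairs; merge all accumulators per key.
    return rdd.reduceByKey(lambda a, b: a.add_accumulator(b))

For the input above this reduces to {1: 41, 2: 47, 3: 33} once get_metrics() is applied to each merged accumulator, which is exactly the dictionary asserted in the test.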
Example #6
    @classmethod
    def setUpClass(cls):
        conf = pyspark.SparkConf()
        cls.sc = pyspark.SparkContext(conf=conf)
        cls.data_extractors = DataExtractors(
            partition_extractor=lambda x: x[1],
            privacy_id_extractor=lambda x: x[0],
            value_extractor=lambda x: x[2])
        cls.ops = SparkRDDOperations()
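These snippets all come from one test module, presumably Google's pipeline_dp; roughly the following scaffolding is assumed around them (the import paths are a guess for the library version the tests target and may differ):

import unittest

import pyspark

from pipeline_dp import DataExtractors                            # path assumed
from pipeline_dp.pipeline_operations import SparkRDDOperations    # path assumed
from pipeline_dp.accumulator import SumAccumulator                # path assumed


class SparkRDDOperationsTest(unittest.TestCase):

    @classmethod
    def tearDownClass(cls):
        # Typical counterpart to setUpClass: release the SparkContext.
        cls.sc.stop()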
    def test_flat_map(self):
        spark_operations = SparkRDDOperations()
        data = [[1, 2, 3, 4], [5, 6, 7, 8]]
        dist_data = SparkRDDOperationsTest.sc.parallelize(data)
        self.assertEqual(
            spark_operations.flat_map(dist_data, lambda x: x).collect(),
            [1, 2, 3, 4, 5, 6, 7, 8])

        data = [("a", [1, 2, 3, 4]), ("b", [5, 6, 7, 8])]
        dist_data = SparkRDDOperationsTest.sc.parallelize(data)
        self.assertEqual(
            spark_operations.flat_map(dist_data, lambda x: x[1]).collect(),
            [1, 2, 3, 4, 5, 6, 7, 8])
        self.assertEqual(
            spark_operations.flat_map(dist_data,
                                      lambda x: [(x[0], y)
                                                 for y in x[1]]).collect(),
            [("a", 1), ("a", 2), ("a", 3), ("a", 4), ("b", 5), ("b", 6),
             ("b", 7), ("b", 8)])