    def test_transform_with_repartition(self):
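        """Repartitioning a SparkXShards should preserve the invariant checked
        throughout this test: every resulting partition holds at most one
        element. Verified for shards of pandas DataFrames, plain lists and
        numpy arrays."""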
        # shards of pandas dataframe
        file_path = os.path.join(self.resource_path, "orca/data/csv")
        data_shard = zoo.orca.data.pandas.read_csv(file_path)
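        # glom() turns each partition into a list, exposing per-partition sizes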
        partitions = data_shard.rdd.glom().collect()
        for par in partitions:
            assert len(par) <= 1

        def negative(df, column_name):
            df[column_name] = df[column_name] * (-1)
            return df

        shard2 = data_shard.transform_shard(negative, "sale_price")
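        # transform_shard applies negative() to every shard, forwarding the
        # extra "sale_price" argument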

        shard3 = shard2.repartition(4)
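        # growing the partition count should still leave at most one element per partition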
        partitions3 = shard3.rdd.glom().collect()
        for par in partitions3:
            assert len(par) <= 1

        shard4 = shard2.repartition(1)
        partitions4 = shard4.rdd.glom().collect()
        for par in partitions4:
            assert len(par) <= 1

        shard5 = shard4.transform_shard(negative, "sale_price")
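        # transforming after repartition should keep the one-element-per-partition layout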
        partitions5 = shard5.rdd.glom().collect()
        for par in partitions5:
            assert len(par) <= 1

        # shards of list
        data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
        sc = init_nncontext()
        rdd = sc.parallelize(data)
        data_shard = SparkXShards(rdd)
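        # 4 elements over 6 partitions: some partitions stay empty, none holds more than one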
        shard2 = data_shard.repartition(6)
        partitions2 = shard2.rdd.glom().collect()
        for par in partitions2:
            assert len(par) <= 1

        shard3 = data_shard.repartition(1)
        partitions3 = shard3.rdd.glom().collect()
        for par in partitions3:
            assert len(par) <= 1

        # shards of numpy array
        data = [
            np.array([1, 2, 3, 4]),
            np.array([5, 6, 7, 8]),
            np.array([9, 10, 11, 12]),
            np.array([13, 14, 15, 16])
        ]
        rdd = sc.parallelize(data)
        data_shard = SparkXShards(rdd)
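        # the same invariant should hold for shards of numpy arrays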
        shard2 = data_shard.repartition(6)
        partitions2 = shard2.rdd.glom().collect()
        for par in partitions2:
            assert len(par) <= 1

        shard3 = data_shard.repartition(1)
        partitions3 = shard3.rdd.glom().collect()
        for par in partitions3:
            assert len(par) <= 1

    def test_zip(self):
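        """zip should pair corresponding elements of two SparkXShards and
        raise a descriptive error when the inputs cannot be zipped."""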
        def negative(df, column_name, minus_val):
            df[column_name] = df[column_name] * (-1)
            df[column_name] = df[column_name] - minus_val
            return df

        file_path = os.path.join(self.resource_path, "orca/data/json")
        data_shard = zoo.orca.data.pandas.read_json(file_path, orient='columns', lines=True)
        data_shard = data_shard.repartition(2)
        data_shard.cache()
        transformed_shard = data_shard.transform_shard(negative, "value", 2)
        zipped_shard = data_shard.zip(transformed_shard)
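        # zip pairs corresponding shards partition-wise; transformed_shard is
        # expected to be uncached afterwards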
        assert not transformed_shard.is_cached(), "transformed_shard should be uncached."
        data = zipped_shard.collect()
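        # negative() maps each value v to -v - 2, so v + (-v - 2) == -2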
        assert data[0][0]["value"].values[0] + data[0][1]["value"].values[0] == -2, \
            "value should be -2"
        list1 = [1, 2, 3]
        with self.assertRaises(Exception) as context:
            data_shard.zip(list1)
        self.assertTrue('other should be a SparkXShards' in str(context.exception))
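        # repartition to one fewer partition than data_shard to trigger the
        # partition-count mismatch error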
        transformed_shard = transformed_shard.repartition(data_shard.num_partitions() - 1)
        with self.assertRaises(Exception) as context:
            data_shard.zip(transformed_shard)
        self.assertTrue('The two SparkXShards should have the same number of partitions' in
                        str(context.exception))
        dict_data = [{"x": 1, "y": 2}, {"x": 2, "y": 3}]
        sc = init_nncontext()
        rdd = sc.parallelize(dict_data)
        dict_shard = SparkXShards(rdd)
        dict_shard = dict_shard.repartition(1)
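        # partition counts now match, but the per-partition element counts differ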
        with self.assertRaises(Exception) as context:
            transformed_shard.zip(dict_shard)
        self.assertTrue('The two SparkXShards should have the same number of elements in '
                        'each partition' in str(context.exception))