def get_private_reader(self, *ignore, metadata, privacy, database, **kwargs):
    """Build a ``PrivateReader`` on top of one of the registered connections.

    Extra positional arguments (``*ignore``) and unrecognised keyword
    arguments (``**kwargs``) are accepted and discarded.

    Args:
        metadata: Forwarded unchanged to ``PrivateReader.from_connection``.
        privacy: Forwarded unchanged to ``PrivateReader.from_connection``.
        database: Key into ``self.connections`` selecting the connection.

    Returns:
        The constructed ``PrivateReader``, or ``None`` when ``database`` is
        not a key of ``self.connections``.
    """
    if database not in self.connections:
        return None

    # Imported lazily, only once a matching connection is known to exist.
    from snsql.sql import PrivateReader

    source = self.connections[database]
    on_spark = self.engine.lower() == "spark"

    if on_spark:
        # The stored object (presumably a Spark DataFrame -- confirm against
        # the code that fills self.connections) is published as the "PUMS"
        # temp view, except for 'pums_large'; the reader itself is then
        # built on the shared session rather than on that object.
        if database.lower() != 'pums_large':
            source.createOrReplaceTempView("PUMS")
        source = self.session

    reader = PrivateReader.from_connection(
        source,
        metadata=metadata,
        privacy=privacy,
    )

    if on_spark:
        reader.reader.compare.search_path = ["PUMS"]

    return reader
from snsql.sql.privacy import Privacy
from snsql.sql.parse import QueryParser

# Locate the repository root by asking git, so the dataset paths below do not
# depend on the directory the tests are launched from.
git_root_dir = subprocess.check_output(
    "git rev-parse --show-toplevel".split(" ")).decode("utf-8").strip()

# Metadata description and CSV data, both under <repo root>/datasets.
meta_path = os.path.join(git_root_dir, os.path.join("datasets", "PUMS_pid.yaml"))
csv_path = os.path.join(git_root_dir, os.path.join("datasets", "PUMS_pid.csv"))

meta = Metadata.from_file(meta_path)
pums = pd.read_csv(csv_path)

# One query exercising several aggregates over the same column.
query = 'SELECT AVG(age), STD(age), VAR(age), SUM(age), COUNT(age) FROM PUMS.PUMS GROUP BY sex'
# NOTE(review): `q` is not referenced again in the visible part of this file.
q = QueryParser(meta).query(query)

# delta evaluates to 1 / (10 * 100) = 0.001.
privacy = Privacy(alphas=[0.01, 0.05], delta=1 / (math.sqrt(100) * 100))
priv = PrivateReader.from_connection(pums, privacy=privacy, metadata=meta)

# `_rewrite` is a private method of the reader; its two results feed Accuracy.
subquery, root = priv._rewrite(query)

# Shared by the tests below unless a test builds its own instance.
acc = Accuracy(root, subquery, privacy)


class TestAccuracy:
    """Tests of the values returned by ``Accuracy.count`` for the query above."""

    def test_count_accuracy(self):
        """Check ``count`` at alpha=0.05 and alpha=0.01 against fixed bounds.

        The alpha=0.05 value must lie strictly between 0.5 and 7.53978; the
        alpha=0.01 value must be below 9.909 and larger than the alpha=0.05
        value.
        """
        error = acc.count(alpha=0.05)
        assert (error < 7.53978 and error > 0.5)
        error_wide = acc.count(alpha=0.01)
        assert (error_wide < 9.909)
        assert (error_wide > error)

    def test_count_accuracy_small_delta(self):
        # This local `acc` shadows the module-level one; it is built with an
        # explicit epsilon=1.0 and delta=0.1 instead of the shared `privacy`.
        acc = Accuracy(root, subquery, privacy=Privacy(epsilon=1.0, delta=0.1))