Example #1
def encode_shares(
    batch_id,
    n_data,
    public_key_hex_internal,
    public_key_hex_external,
    input,
    output_a,
    output_b,
):
    click.echo("Running encode shares")
    spark = spark_session()
    # Encode each partition of raw payloads into a pair of shares and tag every
    # row with a random UUID so the two outputs can be joined later.
    shares = (
        spark.read.json(input)
        .withColumn("pid", F.spark_partition_id())
        .groupBy("pid")
        .applyInPandas(
            lambda pdf: udf.encode(
                batch_id, n_data, public_key_hex_internal, public_key_hex_external, pdf
            ),
            schema="a: binary, b: binary",
        )
        .withColumn("id", F.udf(lambda: str(uuid4()), returnType="string")())
    )
    shares.cache()

    # Estimate the dataset size from a single encoded row. Note that n_rows,
    # scale, and partition_size_mb are not defined in this excerpt; they are
    # assumed to be provided by the surrounding module or CLI options.
    row = shares.first()
    dataset_estimate_mb = (
        (len(b64encode(row.a)) + len(str(uuid4()))) * n_rows * scale * 1.0 / 10**6
    )
    num_partitions = math.ceil(dataset_estimate_mb / partition_size_mb)
    click.echo(f"writing {num_partitions} partitions")

    # Write the base64-encoded "a" and "b" shares to their respective outputs.
    repartitioned = shares.repartitionByRange(num_partitions, "id").cache()
    repartitioned.select("id", F.base64("a").alias("payload")).write.json(
        output_a, mode="overwrite"
    )
    repartitioned.select("id", F.base64("b").alias("payload")).write.json(
        output_b, mode="overwrite"
    )
Example #2
def tokenize(column):
    """
    Thie function is to tokenize base64
    :param column: pass column to be tokensize
    :return: tokenized column
    """
    return fn.base64(column)
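A possible usage of tokenize, assuming fn is an alias for pyspark.sql.functions: the base64 string it produces survives text formats such as JSON and can be recovered with unbase64. The DataFrame below is illustrative only.

from pyspark.sql import SparkSession
from pyspark.sql import functions as fn

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(bytearray(b"hello"),)], ["payload"])

encoded = df.select(tokenize(fn.col("payload")).alias("payload_b64"))
encoded.show()  # shows "aGVsbG8="
decoded = encoded.select(fn.unbase64("payload_b64").alias("payload"))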
Example #3
def test_encode_bad_data(spark, root, args):
    # 2 good rows and 1 bad row (payload has the wrong length and negative values)
    data = [{"payload": [1 for _ in range(args.n_data)]} for _ in range(2)] + [
        {"payload": [-1 for _ in range(args.n_data + 1)]}
    ]
    df = spark.createDataFrame(data)
    df.show()
    transformed = df.select(
        F.pandas_udf(
            partial(
                udf.encode,
                args.batch_id,
                args.n_data,
                args.public_key_hex_internal,
                args.public_key_hex_external,
            ),
            returnType="a: binary, b: binary",
        )("payload").alias("shares")
    ).select(F.base64("shares.a").alias("a"), F.base64("shares.b").alias("b"))
    transformed.show()
    assert transformed.where("a IS NOT NULL").count() == 2
    assert transformed.where("a IS NULL").count() == 1
Example #4
def test_encode(spark, root, args):
    df = spark.read.json(str(root / "client"))
    transformed = df.select(
        F.pandas_udf(
            partial(
                udf.encode,
                args.batch_id,
                args.n_data,
                args.public_key_hex_internal,
                args.public_key_hex_external,
            ),
            returnType="a: binary, b: binary",
        )("payload").alias("shares")
    ).select(F.base64("shares.a").alias("a"), F.base64("shares.b").alias("b"))

    # assert the shares are all the same length
    server_a_payload_len = (
        spark.read.json(str(root / "server_a" / "raw"))
        .select((F.expr("avg(length(payload))")).alias("len"))
        .first()
        .len
    )
    # jq '.payload | length' tests/resources/cli/server_a/raw/data.ndjson -> 396
    assert server_a_payload_len == 396
    assert (
        transformed.where(f"abs(length(a) - {server_a_payload_len}) > 1").count() == 0
    )

    server_b_payload_len = (
        spark.read.json(str(root / "server_b" / "raw"))
        .select((F.expr("avg(length(payload))")).alias("len"))
        .first()
        .len
    )
    # jq '.payload | length' tests/resources/cli/server_b/raw/data.ndjson -> 208
    assert server_b_payload_len == 208
    assert (
        transformed.where(f"abs(length(b) - {server_b_payload_len}) > 1").count() == 0
    )
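The expected lengths follow from how base64 works: every 3 raw bytes become 4 output characters, so an n-byte payload encodes to 4 * ceil(n / 3) characters. A quick sanity check of that relationship (297 and 156 bytes are just example sizes that produce 396- and 208-character strings; the actual fixture sizes are whatever the test resources contain):

import math
from base64 import b64encode

def base64_length(n_bytes: int) -> int:
    # base64 emits 4 characters (including padding) per 3-byte group
    return 4 * math.ceil(n_bytes / 3)

assert base64_length(297) == len(b64encode(b"\x00" * 297)) == 396
assert base64_length(156) == len(b64encode(b"\x00" * 156)) == 208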
Example #5
def aggregate(
    batch_id,
    n_data,
    server_id,
    private_key_hex,
    shared_secret,
    public_key_hex_internal,
    public_key_hex_external,
    input,
    input_internal,
    input_external,
    output,
):
    """Generate an aggregate share from a batch of verified SNIPs"""
    click.echo("Running aggregate")
    spark = spark_session()
    shares = spark.read.json(input)
    internal = spark.read.json(input_internal)
    external = spark.read.json(input_external)

    args = [
        batch_id,
        n_data,
        server_id,
        private_key_hex,
        b64decode(shared_secret),
        public_key_hex_internal,
        public_key_hex_external,
    ]
    # Join the raw shares with the internal and external verification payloads,
    # decode everything from base64, aggregate per partition, combine the partial
    # aggregates into a single total share, and re-encode before writing.
    (
        shares.join(internal.withColumnRenamed("payload", "internal"), on="id")
        .join(external.withColumnRenamed("payload", "external"), on="id")
        .select(
            F.unbase64("payload").alias("shares"),
            F.unbase64("internal").alias("internal"),
            F.unbase64("external").alias("external"),
            F.spark_partition_id().alias("pid"),
        )
        .groupBy("pid")
        .applyInPandas(
            lambda pdf: udf.aggregate(*args, pdf),
            schema="payload: binary, error: int, total: int",
        )
        .groupBy()
        .applyInPandas(
            lambda pdf: udf.total_share(*args, pdf),
            schema="payload: binary, error: int, total: int",
        )
        .withColumn("payload", F.base64("payload"))
    ).write.json(output, mode="overwrite")
Example #6
def transform_col_binary(data_frame):
    """
    Transform every binary column of a DataFrame into a base64-encoded string column.

    Parameters
    ----------
    data_frame: DataFrame
        dataframe to be transformed

    Returns
    -------
    DataFrame
        A transformed dataframe in which each binary column is base64 encoded
        and renamed with a 'BASE64_' prefix.
    """

    def encode_if_binary(df, col_dtype):
        # Tuple unpacking in lambda arguments is not valid Python 3, so the
        # original reduce lambda is expressed as a named helper instead.
        col_name, dtype = col_dtype
        if dtype == "binary":
            return df.withColumn(col_name, base64(col(col_name))).withColumnRenamed(
                col_name, "BASE64_" + col_name
            )
        return df

    return reduce(encode_if_binary, data_frame.dtypes, data_frame)
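A possible usage sketch for transform_col_binary, assuming the imports the function relies on (reduce from functools, base64 and col from pyspark.sql.functions) and an illustrative DataFrame:

from functools import reduce

from pyspark.sql import SparkSession
from pyspark.sql.functions import base64, col

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, bytearray(b"hello"))], ["id", "payload"])

# the binary column becomes a base64 string and is renamed BASE64_payload
transform_col_binary(df).show()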
Example #7
def test_verify2(spark, root, args):
    raw = spark.read.json(str(root / "server_a" / "raw"))
    internal = spark.read.json(
        str(root / "server_a" / "intermediate" / "internal" / "verify1")
    )
    external = spark.read.json(
        str(root / "server_a" / "intermediate" / "external" / "verify1")
    )

    actual = (
        raw.select("id", F.unbase64("payload").alias("shares"))
        .join(internal.select("id", F.unbase64("payload").alias("internal")), on="id")
        .join(external.select("id", F.unbase64("payload").alias("external")), on="id")
        .select(
            "id",
            F.base64(
                F.pandas_udf(
                    partial(
                        udf.verify2,
                        args.batch_id,
                        args.n_data,
                        args.server_id,
                        args.private_key_hex,
                        args.shared_secret,
                        args.public_key_hex_internal,
                        args.public_key_hex_external,
                    ),
                    returnType="binary",
                )("shares", "internal", "external")
            ).alias("expected_payload"),
        )
    )

    expected = spark.read.json(
        str(root / "server_a" / "intermediate" / "internal" / "verify2")
    )

    joined = actual.join(expected, on="id")
    assert joined.where("length(expected_payload) <> length(payload)").count() == 0
Example #8
def test_verify1(spark, root, args):
    df = spark.read.json(str(root / "server_a" / "raw"))
    df.show(vertical=True, truncate=100)

    actual = df.select(
        "id",
        F.base64(
            F.pandas_udf(
                partial(
                    udf.verify1,
                    args.batch_id,
                    args.n_data,
                    args.server_id,
                    args.private_key_hex,
                    args.shared_secret,
                    args.public_key_hex_internal,
                    args.public_key_hex_external,
                ),
                returnType="binary",
            )(F.unbase64("payload"))
        ).alias("expected_payload"),
    )

    expected = spark.read.json(
        str(root / "server_a" / "intermediate" / "internal" / "verify1")
    )

    joined = actual.join(expected, on="id")
    joined.show(vertical=True, truncate=100)

    # NOTE: Payloads are only the same if they are processed in a deterministic
    # order using the same context due to the pseudorandom seed. The CLI
    # application assumes the same server context across all of the rows in a
    # partition. However, the UDF approach will generate a new server context
    # for every row.
    assert joined.where("length(expected_payload) <> length(payload)").count() == 0
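The per-partition context the NOTE refers to can be reproduced with the same groupBy on spark_partition_id plus applyInPandas pattern the CLI uses: anything constructed at the top of the grouped function is shared by every row of that partition, whereas a pandas_udf may construct it once per record batch. A generic sketch of the per-partition variant, with no prio-specific code:

from uuid import uuid4

import pandas as pd
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.range(100).withColumn("pid", F.spark_partition_id())

def tag_with_context(pdf: pd.DataFrame) -> pd.DataFrame:
    context_id = str(uuid4())  # stands in for a per-partition server context
    return pdf.assign(context=context_id)

tagged = df.groupBy("pid").applyInPandas(
    tag_with_context, schema="id: long, pid: int, context: string"
)
# every row within a partition carries the same context id
tagged.groupBy("pid").agg(F.countDistinct("context")).show()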