示例#1
0
def initialization(*args):
    """
    Initialize a worker process's module-level state from serialized data.

    This function is the first one to be passed in the input_queue queue of the process.
    It first deserializes the TenSEAL context and stores it in the global ``context_``,
    i.e. in the memory space of the process. The batch of samples and labels is then
    deserialized with that context into lists of CKKS vectors (``local_X_``/``local_Y_``),
    which stand for the encrypted data the process will work on.

    :param args: tuple ``(b_context, b_X, b_Y, keys)``
        b_context: binary representation of the context, as produced by ``context_.serialize()``
        b_X: list of binary representations of samples in CKKS vector format
        b_Y: list of binary representations of labels in CKKS vector format
        keys: keys of the samples passed to the subprocess. The local ``b_X[i]`` is the
            global ``X[keys[i]]``; useful to map predictions back to true labels.
    :return: str, a status message with the current process name and the number of
        samples loaded.
    """
    global context_, local_X_, local_Y_, local_keys

    b_context = args[0]
    b_X = args[1]
    b_Y = args[2]

    # The context must be deserialized first: the vectors are rebuilt against it.
    context_ = ts.context_from(b_context)
    local_X_ = [ts.ckks_vector_from(context_, sample) for sample in b_X]
    local_Y_ = [ts.ckks_vector_from(context_, label) for label in b_Y]
    local_keys = args[3]

    process_name = multiprocessing.current_process().name
    return f'Initialization done for process {process_name}. Len of data : {len(local_X_)}'
示例#2
0
    def get_context(self, ctx_id: str) -> ts._ts_cpp.TenSEALContext:
        """Get a previously registered context using a context_id

        Args:
            ctx_id: id of a previously registered context

        Returns:
            TenSEALContext

        Raises:
            ConnectionError: if a connection can't be established with the API
            ResourceNotFound: if the context identified with `ctx_id` can't be found
            Answer418: if response.status_code is 418
            ServerError: if response.status_code is 500
        """
        url = self._base_url + "/contexts/"
        params = {"context_id": ctx_id}

        try:
            response = requests.get(url, params=params)
        except requests.exceptions.ConnectionError as err:
            # Surface the documented exception type, but keep the underlying
            # network error as the cause so it is not lost when debugging.
            raise ConnectionError from err

        if response.status_code != 200:
            # Expected to raise one of the documented error types.
            Client._handle_error_response(response)

        # The API returns the serialized context base64-encoded inside JSON.
        ser_ctx = response.json()["context"]
        return ts.context_from(b64decode(ser_ctx))
示例#3
0
def test_serialization_parameters_drop_publickey(encryption_type):
    """Serializing without the public key (and optionally the secret key) drops exactly those keys."""
    source = ctx(encryption_type)
    source.generate_galois_keys()
    asymmetric = encryption_type is ts.ENCRYPTION_TYPE.ASYMMETRIC

    assert source.has_relin_keys()
    assert source.has_galois_keys()
    assert source.has_secret_key()
    if asymmetric:
        assert source.has_public_key()

    # First pass drops only the public key; second pass drops public and secret keys.
    for keep_secret in (True, False):
        restored = ts.context_from(
            source.serialize(
                save_public_key=False,
                save_secret_key=keep_secret,
                save_galois_keys=True,
                save_relin_keys=True,
            )
        )

        assert restored.has_relin_keys()
        assert restored.has_galois_keys()
        if keep_secret:
            assert restored.has_secret_key()
        else:
            assert not restored.has_secret_key()

        if asymmetric:
            assert not restored.has_public_key()
示例#4
0
def context_proto2object(proto: VendorBytes_PB) -> ts.Context:
    """Rebuild a TenSEAL context from a vendor-bytes protobuf message.

    Fails (via ``traceback_and_raise``) when the vendor library named in the
    message is not imported; logs a warning when the message was produced by a
    newer library version than the local ``tenseal``.
    """
    vendor_lib = proto.vendor_lib
    sender_version = version.parse(proto.vendor_lib_version)

    if vendor_lib not in sys.modules:
        traceback_and_raise(
            Exception(f"{vendor_lib} version: {proto.vendor_lib_version} is required")
        )
    elif sender_version > version.parse(ts.__version__):
        # Version mismatch is only logged, not treated as an error.
        info(f"Warning {sender_version} > local imported version {ts.__version__}")

    return ts.context_from(proto.content, n_threads=1)
    def prepare_input(context: bytes, ckks_vector: bytes) -> ts._ts_cpp.CKKSVector:
        """Deserialize a context and an encrypted CKKS vector, and validate the context.

        Args:
            context: serialized TenSEAL context.
            ckks_vector: serialized CKKSVector encrypted under ``context``.

        Returns:
            The deserialized CKKSVector.

        Raises:
            DeserializationError: if the context or the CKKS vector can't be deserialized.
            InvalidContext: if the context doesn't hold galois keys.
        """
        # TODO: check parameters or size and raise InvalidParameters when needed
        try:
            ctx = ts.context_from(context)
            enc_x = ts.ckks_vector_from(ctx, ckks_vector)
        except Exception as err:
            # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
            # still propagate; chain the cause for debugging.
            raise DeserializationError("cannot deserialize context or ckks_vector") from err

        # TODO: replace this with a more flexible check when introduced in the API
        # (newer TenSEAL versions expose ctx.has_galois_keys()).
        try:
            _ = ctx.galois_keys()
        except Exception as err:
            raise InvalidContext("the context doesn't hold galois keys") from err

        return enc_x
示例#6
0
    def prepare_input(context, ckks_vector):
        """Deserialize a context and an encrypted CKKS vector, and validate the context.

        Args:
            context: serialized TenSEAL context.
            ckks_vector: serialized CKKSVector encrypted under ``context``.

        Returns:
            The deserialized CKKSVector.

        Raises:
            DeserializationError: if the context or the CKKS vector can't be deserialized.
            InvalidContext: if the context doesn't hold galois keys.
        """
        try:
            ctx = ts.context_from(context)
            enc_x = ts.ckks_vector_from(ctx, ckks_vector)
        except Exception as err:
            # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
            # still propagate; chain the cause for debugging.
            raise DeserializationError(
                "cannot deserialize context or ckks_vector") from err

        # TODO: replace this with a more flexible check when introduced in the API
        # (newer TenSEAL versions expose ctx.has_galois_keys()).
        try:
            _ = ctx.galois_keys()
        except Exception as err:
            raise InvalidContext("the context doesn't hold galois keys") from err

        return enc_x
示例#7
0
def load_ctx_and_input(
    context_file: typer.FileBinaryRead,
    input_file: typer.FileBinaryRead = None
) -> Tuple[ts._ts_cpp.TenSEALContext, ts._ts_cpp.CKKSVector]:
    """Deserialize a TenSEAL context and, optionally, an encrypted CKKS input.

    Args:
        context_file: binary file holding a serialized TenSEAL context.
        input_file: optional binary file holding a serialized CKKS vector.

    Returns:
        ``(context, encrypted_input)``; ``encrypted_input`` is ``None`` when no
        ``input_file`` is given.

    Raises:
        typer.Exit: with code 1 when either object fails to deserialize; the
            error is echoed to stderr first.
    """
    try:
        context = ts.context_from(context_file.read())
    except Exception as exc:
        typer.echo(f"Couldn't load context: {str(exc)}", err=True)
        raise typer.Exit(code=1)

    encrypted = None
    if input_file is not None:
        try:
            encrypted = ts.ckks_vector_from(context, input_file.read())
        except Exception as exc:
            typer.echo(f"Couldn't load encrypted input: {str(exc)}", err=True)
            raise typer.Exit(code=1)

    return context, encrypted
示例#8
0
def retrieve_data(my_address, params, index, url):
    """Request encrypted data from ``url`` and deserialize it.

    Posts ``address``, ``index`` and ``params`` as JSON. On success returns
    ``(data, column, ctx)``: the deserialized ``ckks`` and ``col`` payloads and
    the TenSEAL context decoded from the base64 ``ctx`` field.

    If the reply lacks those fields or they fail to deserialize, the decoded
    JSON reply is returned three times in place of that tuple (best-effort
    fallback kept from the original behavior).
    """
    response = post(url,
                    json={
                        'address': my_address,
                        'index': index,
                        'params': params
                    })
    json_obj = response.json()
    try:
        ctx = context_from(b64decode(json_obj['ctx']))
        data = client_deserialize(ctx, json_obj['ckks'])
        column = client_deserialize(ctx, json_obj['col'])
        set_additional_methods(_ts_cpp.CKKSVector)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not
        # swallowed. Presumably the server sent an error payload — verify.
        return json_obj, json_obj, json_obj
    return data, column, ctx
示例#9
0
    def __init__(self,
                 init_weight,
                 init_bias,
                 refresh_function,
                 context,
                 confidential_kwarg,
                 learning_rate=1,
                 max_epoch=100,
                 reg_para=0.5,
                 verbose=-1,
                 save_weight=-1,
                 n_jobs=1):
        """

            Constructor


            :param context: TenSEAL context, either as an object or as its serialized bytes.
                            It holds the public key, the relin key and the galois key, which are
                            mandatory for computation and deserialization.
            :param n_jobs: multiprocessing. Equal to the number of processes that will be created and launched
            :param init_weight: CKKS vector. Initial weight
            :param init_bias: CKKS vector. Initial bias
            :param refresh_function: function. Refresh ciphertext
            :param confidential_kwarg: dict. Will be passed as **kwarg to refresh, loss and accuracy functions. Contain confidential data which are needed by those functions.
            :param learning_rate: float. learning rate
            :param max_epoch: int. number of epochs to be performed
            :param reg_para: float. regularization parameter
            :param verbose: int. interval, in epochs, at which the loss is computed and logged.
                            Every <verbose> epochs, the loss (and error) will be logged.
                            If set to -1 (or any value <= 0), the loss will not be computed nor stored at all.

            :param save_weight: int. interval, in epochs, at which the weight is stored.
                                Every <save_weight> epochs, the weight will be logged in weight_list.
                                If set to -1 (or any value <= 0), the weight will not be saved.

        """

        self.logger = logging.getLogger(__name__)

        self.refresh_function = refresh_function
        self.confidential_kwarg = confidential_kwarg

        # Keep both the live context and its serialized form; the bytes form is
        # what can be shipped to worker processes.
        if isinstance(context, bytes):
            self.context = ts.context_from(context)
            self.b_context = context
        else:
            self.context = context
            self.b_context = context.serialize()

        self.verbose = verbose
        self.save_weight = save_weight

        self.iter = 0
        self.num_iter = max_epoch
        self.reg_para = reg_para
        self.lr = learning_rate

        self.n_jobs = n_jobs

        # History containers only exist when the corresponding logging is enabled.
        if verbose > 0:
            self.loss_list = []
        if save_weight > 0:
            self.weight_list = []
            self.bias_list = []
        self.weight = init_weight
        self.bias = init_bias
示例#10
0
import tenseal as ts
import socket
import pickle

client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(('localhost', 4472))

# Setting the public and private contexts
private_context = ts.context(ts.SCHEME_TYPE.BFV,
                             poly_modulus_degree=2**14,
                             plain_modulus=2**16 + 1)
public_context = ts.context_from(private_context.serialize())
public_context.make_context_public()

# I create the query to be sent to the server
query = '1001000'
plain_query = [int(query, 2)]
enc_query = ts.bfv_vector(public_context, plain_query)

# I prepare the message that I want to send to the server
enc_query_serialized = enc_query.serialize()
context_serialized = public_context.serialize()
message_to_be_sent = [context_serialized, enc_query_serialized]
message_to_be_sent_serialized = pickle.dumps(message_to_be_sent, protocol=None)

# Here is the length of my message, as decimal text space-padded to 10 bytes
L = len(message_to_be_sent_serialized)
sL = str(L).ljust(10)

# Send the fixed-size length header, then the payload itself: the receiver
# reads exactly L bytes after the header, so the payload must follow it.
client.sendall(sL.encode())
client.sendall(message_to_be_sent_serialized)

# NOTE(review): everything below is server-side code. `serv`, `new_database`
# and `len_database` are not defined in this snippet and must be provided by
# the server script this loop belongs to.


def _recv_exact(sock, size):
    """Read up to ``size`` bytes from ``sock``, stopping early only if the peer closes."""
    buf = bytearray()
    while len(buf) < size:
        chunk = sock.recv(min(4096, size - len(buf)))
        if not chunk:
            break
        buf += chunk
    return bytes(buf)


for _ in range(3):
    conn, addr = serv.accept()
    # `with` guarantees the connection is closed after each request.
    with conn:
        # A single recv(10) may return fewer than 10 bytes; read the header fully.
        L = int(_recv_exact(conn, 10).decode().strip(), 10)

        # Getting bytes of context and encrypted query
        final_data = _recv_exact(conn, L)
        # SECURITY: pickle.loads on bytes received from the network can execute
        # arbitrary code; only use this with trusted peers or switch to a safe
        # framing (e.g. length-prefixed raw serialized blobs).
        deserialized_message = pickle.loads(final_data)

        # Here we recover the context and ciphertext received from the client
        context = ts.context_from(deserialized_message[0])
        ciphertext = deserialized_message[1]
        ct = ts.bfv_vector_from(context, ciphertext)

        # Evaluate the database polynomial at the ciphertext received from the client
        response = ct - [new_database[0]]
        for j in range(1, len_database):
            factor = ct - [new_database[j]]
            response = response * factor

        # Prepare the answer to be sent to the client
        response_serialized = response.serialize()
        response_to_be_sent = pickle.dumps(response_serialized, protocol=None)
        conn.sendall(response_to_be_sent)
示例#12
0
 def __init__(self):
     """Load the TenSEAL context serialized in ``public_key.bin`` into ``self.public_context``."""
     with open('public_key.bin', 'rb') as key_file:
         serialized_context = key_file.read()
     self.public_context = ts.context_from(serialized_context)
示例#13
0
def recreate(ctx):
    """Return a new context obtained by serializing ``ctx`` and deserializing the result."""
    return ts.context_from(ctx.serialize())
示例#14
0
def load_context(file_path):
    """Read the serialized TenSEAL context stored at ``file_path`` and deserialize it."""
    import tenseal as ts

    with open(file_path, "rb") as handle:
        serialized = handle.read()
    return ts.context_from(serialized)
示例#15
0
import tenseal as ts
import time
import p2p

bin_name = 'vectorReceive.bin'
context_name = 'contextReceive.bin'

# Receive the encrypted vector first, then the context; each transfer uses a
# fresh receiver bound to the same address.
for target in (bin_name, context_name):
    receiver = p2p.Receiver('127.0.0.1', 8080)
    receiver.receive(target)


def _read_bytes(path):
    """Return the full binary contents of ``path``."""
    with open(path, 'rb') as f:
        return f.read()


vector_enc_bin = _read_bytes(bin_name)
context_bin = _read_bytes(context_name)

# The vector can only be rebuilt once its context is available.
context = ts.context_from(context_bin)
vector_enc = ts.ckks_vector_from(context, vector_enc_bin)

vector = vector_enc.decrypt()

print("Received vector:", vector)
示例#16
0
def recreate(ctx):
    """Return a copy of ``ctx`` rebuilt from a serialization that keeps the secret key."""
    serialized = ctx.serialize(save_secret_key=True)
    return ts.context_from(serialized)
示例#17
0
def load_context(ctx_id: str) -> ts._ts_cpp.TenSEALContext:
    """Load a TenSEALContext: fetch its raw bytes via ``get_raw_context`` and deserialize them."""
    return ts.context_from(get_raw_context(ctx_id))
示例#18
0
 def deserialize(self):
     """Rebuild a TenSEAL context from the bytes held in ``self.context_bytes``."""
     serialized = self.context_bytes
     return ts.context_from(serialized)