예제 #1
0
def start_master(cluster_config_file=None):
    """Start alice's process: configure TF Encrypted from a cluster file, then run ``main``.

    Args:
        cluster_config_file: path handed to ``tfe.RemoteConfig.load``.
            NOTE(review): the default ``None`` is forwarded unchanged to
            ``RemoteConfig.load`` -- confirm that call accepts ``None``.
    """
    print("Starting alice...")
    cfg = tfe.RemoteConfig.load(cluster_config_file)
    tfe.set_config(cfg)
    tfe.set_protocol(tfe.protocol.Pond())

    # alice is the first entry of the configured player list.
    alice_server = cfg.server(cfg.players[0].name)
    main(alice_server)
예제 #2
0
def start_slave(cluster_config_file):
    """Start bob's process: configure TF Encrypted and join the server for player index 1.

    Args:
        cluster_config_file: path handed to ``tfe.RemoteConfig.load``.
    """
    print("Starting bob...")
    cfg = tfe.RemoteConfig.load(cluster_config_file)
    tfe.set_config(cfg)
    tfe.set_protocol(tfe.protocol.Pond())

    # bob is the second entry of the configured player list.
    bob_name = cfg.players[1].name
    bob_server = cfg.server(bob_name)
    print("server_name = ", bob_name)
    bob_server.join()
def start_slave(cluster_config_file):
    """Start the crypto producer: configure TF Encrypted and join the server for player index 2.

    NOTE(review): in this file this definition shadows the preceding
    ``start_slave`` (bob); they appear to come from separate examples.

    Args:
        cluster_config_file: path handed to ``tfe.RemoteConfig.load``.
    """
    print("Starting crypto producer...")
    cfg = tfe.RemoteConfig.load(cluster_config_file)
    tfe.set_config(cfg)
    tfe.set_protocol(tfe.protocol.Pond())

    # The crypto producer is the third entry of the configured player list.
    producer_name = cfg.players[2].name
    producer_server = cfg.server(producer_name)
    print("server_name = ", producer_name)
    producer_server.join()
예제 #4
0
def start_master(cluster_config_file=None):
    """Start alice's process, run ``main`` and print its wall-clock duration.

    Args:
        cluster_config_file: path handed to ``tfe.RemoteConfig.load``.
            NOTE(review): the default ``None`` is forwarded unchanged to
            ``RemoteConfig.load`` -- confirm that call accepts ``None``.
    """
    print("Starting alice...")
    cfg = tfe.RemoteConfig.load(cluster_config_file)
    tfe.set_config(cfg)
    tfe.set_protocol(tfe.protocol.Pond())
    alice_server = cfg.server(cfg.players[0].name)

    # Only the call to main() is timed; setup above is excluded.
    started = time.perf_counter()
    main(alice_server)
    elapsed = time.perf_counter() - started
    print(f'Elapsed time: {elapsed}s')
예제 #5
0
    def connect_to_model(self, input_shape, output_shape, *workers):
        """Connect this client to a TF Encrypted model served by ``workers``.

        Builds a TFE config from ``workers`` via ``self.config_from_workers``
        (its second return value is discarded). Installs a SecureNN protocol
        over the players named "server0", "server1" and "server2", then stores
        a ``tfe.serving.QueueClient`` and a fresh ``tfe.Session`` on ``self``.

        Args:
            input_shape: shape passed to ``QueueClient`` for queries.
            output_shape: shape passed to ``QueueClient`` for results.
            *workers: the workers hosting the model, forwarded as a tuple.
        """
        config, _ = self.config_from_workers(workers)
        tfe.set_config(config)

        protocol_players = [
            config.get_player(name) for name in ("server0", "server1", "server2")
        ]
        tfe.set_protocol(tfe.protocol.SecureNN(*protocol_players))

        self._tf_client = tfe.serving.QueueClient(
            input_shape=input_shape, output_shape=output_shape)
        self._tf_session = tfe.Session(config=config)
예제 #6
0
def _configure_tfe(cluster):

    if not cluster or len(cluster.workers) != 3:
        raise RuntimeError(
            "TF Encrypted expects three parties for its sharing protocols.")

    config = cluster.tfe_config
    tfe.set_config(config)

    prot = tfe.protocol.SecureNN(config.get_player("server0"),
                                 config.get_player("server1"),
                                 config.get_player("server2"))
    tfe.set_protocol(prot)
    tfe.clear_initializers()
예제 #7
0
def _configure_tfe(workers):

    if not workers or len(workers) != 3:
        raise RuntimeError(
            "TF Encrypted expects three parties for its sharing protocols.")

    tfe_worker_cls = workers[0].__class__
    config, player_to_worker_mapping = tfe_worker_cls.config_from_workers(
        workers)
    tfe.set_config(config)

    prot = tfe.protocol.SecureNN(config.get_player("server0"),
                                 config.get_player("server1"),
                                 config.get_player("server2"))
    tfe.set_protocol(prot)

    return player_to_worker_mapping
예제 #8
0
    def connect_to_model(self, input_shape, output_shape, cluster, sess=None):
        """Connect to a TF Encrypted model being served by the given cluster.

        This must be done before querying the model.

        Args:
            input_shape: shape passed to ``tfe.serving.QueueClient`` for queries.
            output_shape: shape passed to ``tfe.serving.QueueClient`` for results.
            cluster: object exposing ``tfe_config``; its players "server0",
                "server1" and "server2" run the SecureNN protocol.
            sess: optional existing session; when ``None`` a new
                ``tfe.Session`` is created from the cluster's config.
        """
        config = cluster.tfe_config
        tfe.set_config(config)

        protocol_players = [
            config.get_player(name) for name in ("server0", "server1", "server2")
        ]
        tfe.set_protocol(tfe.protocol.SecureNN(*protocol_players))

        self._tf_client = tfe.serving.QueueClient(
            input_shape=input_shape, output_shape=output_shape)
        self._tf_session = sess if sess is not None else tfe.Session(config=config)
예제 #9
0
# tfe.setMonitorStatsFlag(True)

# Pick the cluster configuration: a config file path given as the first
# command-line argument wins; otherwise all five roles share a LocalConfig.
if len(sys.argv) < 2:
    config = tfe.LocalConfig([
        'server0',
        'server1',
        'crypto-producer',
        'model-trainer',
        'prediction-client',
    ])
else:
    config_file = sys.argv[1]
    config = tfe.config.load(config_file)
tfe.set_config(config)

# SecureNN receives the players 'server0', 'server1' and 'crypto-producer',
# in that order.
tfe.set_protocol(
    tfe.protocol.SecureNN(
        *tfe.get_config().get_players(['server0', 'server1', 'crypto-producer'])
    )
)


def weight_variable(shape, gain):
    """Create a ``tf.Variable`` drawn from a gain-scaled Glorot-uniform range.

    Values are sampled uniformly from ``[-r, r]`` with
    ``r = gain * sqrt(6 / (fan_in + fan_out))``.

    Args:
        shape: rank-2 ``(fan_in, fan_out)`` or rank-4 ``(h, w, c_in, c_out)``
            weight shape.
        gain: multiplier applied to the Glorot limit.

    Returns:
        A ``tf.Variable`` holding the sampled initial values.

    Raises:
        ValueError: if ``shape`` is neither rank 2 nor rank 4 (previously this
            surfaced as an opaque ``UnboundLocalError``).
    """
    rank = len(shape)
    if rank == 2:
        fan_in, fan_out = shape
    elif rank == 4:
        h, w, c_in, c_out = shape
        # Each output unit sees an h*w window over every input channel.
        fan_in = h * w * c_in
        fan_out = h * w * c_out
    else:
        raise ValueError(
            f"weight_variable supports rank-2 or rank-4 shapes, got rank {rank}: {shape!r}")
    r = gain * math.sqrt(6 / (fan_in + fan_out))
    initial = tf.random_uniform(shape, minval=-r, maxval=r)
    return tf.Variable(initial)
예제 #10
0
import numpy as np
import random as ran


def provide_data(features, batch_size=10):
    """Return the next-batch tensor of an endlessly repeating dataset over ``features``.

    Args:
        features: data accepted by ``tf.data.Dataset.from_tensor_slices``;
            each element must hold 784 values (presumably flattened 28x28
            MNIST images, as in this example's caller -- confirm for others).
        batch_size: rows per batch. Defaults to 10, the previously hard-coded
            value; it now drives both the batching and the reshape so the two
            cannot drift apart.

    Returns:
        A tensor reshaped to ``[batch_size, 784]``.
    """
    dataset = tf.data.Dataset.from_tensor_slices(features)
    # repeat() with no count cycles forever, so (for non-empty features) every
    # batch is full and the static reshape below always matches.
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    batch = iterator.get_next()
    batch = tf.reshape(batch, [batch_size, 784])
    return batch


remote_config = tfe.RemoteConfig.load("config.json")
tfe.set_config(remote_config)

# NOTE(review): server0 is not used in the visible code; kept in case later
# parts of the script rely on it.
players = remote_config.players
server0 = remote_config.server(players[0].name)

# Install the protocol exactly once: a two-party Pond between the "alice" and
# "bob" players of the loaded config.
tfe.set_protocol(
    tfe.protocol.Pond(tfe.get_config().get_player("alice"),
                      tfe.get_config().get_player("bob")))

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
train_data = mnist.train.images[:100, :]
train_labels = mnist.train.labels[:100]

x_train_0 = tfe.define_private_input("alice", lambda: provide_data(train_data))