def test_stack_convert(self):
    """Check that a converted stack graph matches plain TensorFlow.

    Exports a stack graph to ``stack.pb`` and computes the reference result
    with ``run_stack``. It then converts the exported graph under the Pond
    protocol, feeding three separate input providers. Finally it checks the
    revealed output against the reference to 3 decimal places.
    """
    tf.reset_default_graph()

    # Module-level name; presumably read elsewhere (e.g. a tearDown that
    # removes the exported file) -- TODO confirm.
    global global_filename
    global_filename = "stack.pb"

    input1 = np.array([1, 4])
    input2 = np.array([2, 5])
    input3 = np.array([3, 6])

    path = export_stack(global_filename, input1.shape)

    tf.reset_default_graph()
    graph_def = read_graph(path)

    tf.reset_default_graph()
    actual = run_stack(input1, input2, input3)

    tf.reset_default_graph()
    config = tfe.LocalConfig([
        'server0',
        'server1',
        'crypto_producer',
        'prediction_client',
        'weights_provider',
    ])

    with tfe.protocol.Pond(*config.get_players(
            'server0, server1, crypto_producer')) as prot:
        prot.clear_initializers()

        class PredictionClient(tfe.io.InputProvider):
            # Must be a class attribute. Writing ``self.input = None`` here
            # would resolve ``self`` to the enclosing TestCase (class bodies
            # see the enclosing function's locals) and set the attribute
            # on the test case instead.
            input = None

            def provide_input(self):
                return tf.constant(self.input)

        # One provider per stacked operand, created in operand order.
        providers = []
        for data in (input1, input2, input3):
            provider = PredictionClient(config.get_player('prediction_client'))
            provider.input = data
            providers.append(provider)

        converter = Converter(config, prot,
                              config.get_player('weights_provider'))
        x = converter.convert(graph_def, providers, register())

        with config.session() as sess:
            tfe.run(sess, prot.initializer, tag='init')
            output = x.reveal().eval(sess, tag='reveal')

    np.testing.assert_array_almost_equal(output, actual, decimal=3)
def test_strided_slice_convert(self):
    """Check that a converted strided-slice graph matches plain TensorFlow.

    Exports a strided-slice graph to ``strided_slice.pb`` and computes the
    reference result with ``run_strided_slice``. It then converts the graph
    under the Pond protocol and checks the revealed output against the
    reference to 3 decimal places.
    """
    tf.reset_default_graph()

    # Module-level name; presumably read elsewhere (e.g. a tearDown that
    # removes the exported file) -- TODO confirm.
    global global_filename
    global_filename = "strided_slice.pb"

    path = export_strided_slice(global_filename)

    tf.reset_default_graph()
    graph_def = read_graph(path)

    tf.reset_default_graph()
    # Single source of truth for the test input: the reference run and the
    # private input provider below both read this list.
    input_data = [[[1, 1, 1], [2, 2, 2]],
                  [[3, 3, 3], [4, 4, 4]],
                  [[5, 5, 5], [6, 6, 6]]]
    actual = run_strided_slice(input_data)

    tf.reset_default_graph()
    config = tfe.LocalConfig([
        'server0',
        'server1',
        'crypto_producer',
        'prediction_client',
        'weights_provider',
    ])

    with tfe.protocol.Pond(*config.get_players(
            'server0, server1, crypto_producer')) as prot:
        prot.clear_initializers()

        class PredictionClient(tfe.io.InputProvider):
            def provide_input(self):
                # Closes over ``input_data``, which is never rebound.
                return tf.constant(input_data)

        provider = PredictionClient(config.get_player('prediction_client'))

        converter = Converter(config, prot,
                              config.get_player('weights_provider'))
        x = converter.convert(graph_def, provider, register())

        with config.session() as sess:
            tfe.run(sess, prot.initializer, tag='init')
            output = x.reveal().eval(sess, tag='reveal')

    np.testing.assert_array_almost_equal(output, actual, decimal=3)
def test_avgpooling_convert(self):
    """Check that a converted average-pooling graph matches plain TensorFlow.

    Exports an avg-pool graph for ``input_shape`` to ``avgpool.pb`` and
    computes the reference result with ``run_avgpool``. It then converts the
    graph under the Pond protocol, feeding an all-ones input of the same
    shape. Finally it checks the revealed output against the reference to
    3 decimal places.
    """
    tf.reset_default_graph()

    # Module-level name; presumably read elsewhere (e.g. a tearDown that
    # removes the exported file) -- TODO confirm.
    global global_filename
    global_filename = "avgpool.pb"

    input_shape = [1, 28, 28, 1]
    path = export_avgpool(global_filename, input_shape)

    tf.reset_default_graph()
    graph_def = read_graph(path)

    tf.reset_default_graph()
    actual = run_avgpool(input_shape)

    tf.reset_default_graph()
    config = tfe.LocalConfig([
        'server0',
        'server1',
        'crypto_producer',
        'prediction_client',
        'weights_provider',
    ])

    with tfe.protocol.Pond(*config.get_players(
            'server0, server1, crypto_producer')) as prot:
        prot.clear_initializers()

        class PredictionClient(tfe.io.InputProvider):
            def provide_input(self):
                # NOTE(review): assumes run_avgpool also uses an all-ones
                # input of ``input_shape`` -- confirm against its definition.
                return tf.constant(np.ones(input_shape))

        provider = PredictionClient(config.get_player('prediction_client'))

        converter = Converter(config, prot,
                              config.get_player('weights_provider'))
        x = converter.convert(graph_def, provider, register())

        with config.session() as sess:
            tfe.run(sess, prot.initializer, tag='init')
            output = x.reveal().eval(sess, tag='reveal')

    np.testing.assert_array_almost_equal(output, actual, decimal=3)
# Export the CNN to ``cnn.pb``, load it back as a GraphDef, convert it to a
# private computation under the Pond protocol and run one prediction.
export_cnn()

tf.reset_default_graph()

model_filename = 'cnn.pb'
# ``GFile`` replaces the deprecated ``FastGFile`` and takes the same
# (name, mode) arguments.
with gfile.GFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

config = tfe.LocalConfig([
    'server0',
    'server1',
    'crypto_producer',
    'prediction_client',
    'weights_provider',
])

with tfe.protocol.Pond(
        *config.get_players('server0, server1, crypto_producer')) as prot:
    # NOTE: ``input`` shadows the builtin at module level; the name is kept
    # in case code outside this view refers to it.
    input = PredictionInputProvider(config.get_player('prediction_client'))
    output = PredictionOutputReceiver(config.get_player('prediction_client'))

    c = convert.Converter(config, prot, config.get_player('weights_provider'))
    x = c.convert(graph_def, input, register())

    # Route the converted model's result to the prediction client.
    prediction_op = prot.define_output(x, output)

    with config.session() as sess:
        tfe.run(sess, prot.initializer, tag='init')
        tfe.run(sess, prediction_op, tag='prediction')