Beispiel #1
0
    def test_stack_convert(self):
        """Convert a ``stack`` graph and compare it with a reference run.

        The graph is written by ``export_stack``, loaded back with
        ``read_graph`` and evaluated by ``run_stack`` on three length-2 arrays
        to obtain the expected values.  The same graph is then passed through
        ``Converter`` under a Pond protocol, with one input provider per
        array, and the revealed result must match the expected values to
        three decimal places.
        """
        tf.reset_default_graph()

        global global_filename
        global_filename = "stack.pb"

        input1 = np.array([1, 4])
        input2 = np.array([2, 5])
        input3 = np.array([3, 6])

        path = export_stack(global_filename, input1.shape)

        # The default graph is reset between every stage (export, read,
        # reference run, conversion).
        tf.reset_default_graph()

        graph_def = read_graph(path)

        tf.reset_default_graph()

        actual = run_stack(input1, input2, input3)

        tf.reset_default_graph()

        config = tfe.LocalConfig([
            'server0',
            'server1',
            'crypto_producer',
            'prediction_client',
            'weights_provider',
        ])

        with tfe.protocol.Pond(*config.get_players(
                'server0, server1, crypto_producer')) as prot:
            prot.clear_initializers()

            class PredictionClient(tfe.io.InputProvider):
                # Class-level default.  Writing ``self.input`` here would bind
                # to the enclosing test method's ``self`` (the test case), not
                # to this class, because class bodies see enclosing locals.
                input = None

                def provide_input(self):
                    return tf.constant(self.input)

            # One provider per stacked operand, all owned by the same player.
            providers = []
            for values in (input1, input2, input3):
                provider = PredictionClient(
                    config.get_player('prediction_client'))
                provider.input = values
                providers.append(provider)

            converter = Converter(config, prot,
                                  config.get_player('weights_provider'))

            x = converter.convert(graph_def, providers, register())

            with config.session() as sess:
                tfe.run(sess, prot.initializer, tag='init')

                output = x.reveal().eval(sess, tag='reveal')

        np.testing.assert_array_almost_equal(output, actual, decimal=3)
Beispiel #2
0
    def test_strided_slice_convert(self):
        """Convert a strided-slice graph and compare it with a reference run.

        The graph is written by ``export_strided_slice``, loaded back with
        ``read_graph`` and evaluated by ``run_strided_slice`` on a fixed
        3x2x3 nested list to obtain the expected values.  The same graph is
        then passed through ``Converter`` under a Pond protocol, fed with the
        same data, and the revealed result must match the expected values to
        three decimal places.
        """
        global global_filename

        # Fixture shared by the reference run and the prediction client.
        values = [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
                  [[5, 5, 5], [6, 6, 6]]]
        player_names = [
            'server0',
            'server1',
            'crypto_producer',
            'prediction_client',
            'weights_provider',
        ]

        tf.reset_default_graph()
        global_filename = "strided_slice.pb"
        pb_path = export_strided_slice(global_filename)

        tf.reset_default_graph()
        graph_def = read_graph(pb_path)

        tf.reset_default_graph()
        expected = run_strided_slice(values)

        tf.reset_default_graph()
        local_config = tfe.LocalConfig(player_names)
        pond_players = local_config.get_players(
            'server0, server1, crypto_producer')

        with tfe.protocol.Pond(*pond_players) as protocol:
            protocol.clear_initializers()

            class PredictionClient(tfe.io.InputProvider):
                def provide_input(self):
                    # Same data the reference run received.
                    return tf.constant(values)

            client = PredictionClient(
                local_config.get_player('prediction_client'))
            weights_player = local_config.get_player('weights_provider')
            converter = Converter(local_config, protocol, weights_player)

            converted = converter.convert(graph_def, client, register())

            with local_config.session() as session:
                tfe.run(session, protocol.initializer, tag='init')
                revealed = converted.reveal().eval(session, tag='reveal')

        np.testing.assert_array_almost_equal(revealed, expected, decimal=3)
Beispiel #3
0
    def test_avgpooling_convert(self):
        """Convert an avgpool graph and compare it with a reference run.

        The graph is written by ``export_avgpool`` for a ``[1, 28, 28, 1]``
        input shape, loaded back with ``read_graph`` and evaluated by
        ``run_avgpool`` to obtain the expected values.  The same graph is
        then passed through ``Converter`` under a Pond protocol, fed with an
        all-ones array of that shape, and the revealed result must match the
        expected values to three decimal places.
        """
        global global_filename

        input_shape = [1, 28, 28, 1]
        player_names = [
            'server0',
            'server1',
            'crypto_producer',
            'prediction_client',
            'weights_provider',
        ]

        tf.reset_default_graph()
        global_filename = "avgpool.pb"
        pb_path = export_avgpool(global_filename, input_shape)

        tf.reset_default_graph()
        graph_def = read_graph(pb_path)

        tf.reset_default_graph()
        expected = run_avgpool(input_shape)

        tf.reset_default_graph()
        local_config = tfe.LocalConfig(player_names)
        pond_players = local_config.get_players(
            'server0, server1, crypto_producer')

        with tfe.protocol.Pond(*pond_players) as protocol:
            protocol.clear_initializers()

            class PredictionClient(tfe.io.InputProvider):
                def provide_input(self):
                    # All-ones array with the shape the graph was exported for.
                    ones = np.ones(input_shape)
                    return tf.constant(ones)

            client = PredictionClient(
                local_config.get_player('prediction_client'))
            weights_player = local_config.get_player('weights_provider')
            converter = Converter(local_config, protocol, weights_player)

            converted = converter.convert(graph_def, client, register())

            with local_config.session() as session:
                tfe.run(session, protocol.initializer, tag='init')
                revealed = converted.reveal().eval(session, tag='reveal')

        np.testing.assert_array_almost_equal(revealed, expected, decimal=3)