コード例 #1
0
    def _assert_successful_conversion(
        prot,
        graph_def,
        actual,
        *input_fns,
        decimals=3,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Convert ``graph_def`` with ``prot`` and compare revealed outputs.

        Args:
            prot: protocol instance handed to the ``Converter``.
            graph_def: graph definition to convert.
            actual: expected output, or a list/tuple of expected outputs when
                the conversion yields a list/tuple of tensors.
            *input_fns: input functions, passed to ``converter.convert`` as a
                list.
            decimals: precision forwarded to
                ``np.testing.assert_array_almost_equal``.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: if the conversion yields a sequence but ``actual``
                is not one, if the number of outputs differs from the number
                of expected values, or if any output differs from its expected
                value beyond ``decimals``.
        """
        converter = Converter(
            registry(),
            config=tfe.get_config(),
            protocol=prot,
            model_provider="model-provider",
        )
        x = converter.convert(graph_def, "input-provider", list(input_fns))

        with tfe.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if not isinstance(x, (list, tuple)):
                x = [x]
                actual = [actual]
            else:
                assert isinstance(
                    actual, (list, tuple)
                ), "expected output to be tensor sequence"
            # zip() below stops at the shorter sequence; without this check a
            # missing or extra output would make the test pass vacuously.
            assert len(x) == len(actual), (
                f"expected {len(actual)} outputs, conversion produced {len(x)}"
            )
            try:
                fetches = [xi.reveal().decode() for xi in x]
            except AttributeError:
                # If reveal()/decode() is unavailable, assume all xi are public
                # and fetch them directly. Only op construction is guarded, so
                # an AttributeError raised by sess.run itself is not masked.
                fetches = x
            output = sess.run(fetches, tag="reveal")
            for o_i, a_i in zip(output, actual):
                np.testing.assert_array_almost_equal(o_i, a_i, decimal=decimals)
コード例 #2
0
ファイル: test_convert.py プロジェクト: voidxb/tf-encrypted
    def _assert_successful_conversion(prot,
                                      graph_def,
                                      actual,
                                      *input_fns,
                                      decimals=3,
                                      **kwargs):
        """Convert ``graph_def`` with ``prot`` and compare revealed outputs.

        Args:
            prot: protocol instance; its initializers are cleared before
                conversion.
            graph_def: graph definition to convert.
            actual: expected output, or a list/tuple of expected outputs when
                the conversion yields a list/tuple of tensors.
            *input_fns: input functions, passed to ``converter.convert`` as a
                list.
            decimals: precision forwarded to
                ``np.testing.assert_array_almost_equal``.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: if the conversion yields a sequence but ``actual``
                is not one, if the number of outputs differs from the number
                of expected values, or if any output differs from its expected
                value beyond ``decimals``.
        """
        prot.clear_initializers()
        converter = Converter(tfe.get_config(), prot, 'model-provider')
        x = converter.convert(graph_def, registry(), 'input-provider',
                              list(input_fns))

        with tfe.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if not isinstance(x, (list, tuple)):
                x = [x]
                actual = [actual]
            else:
                assert isinstance(
                    actual,
                    (list, tuple)), "expected output to be tensor sequence"
            # zip() below stops at the shorter sequence; without this check a
            # missing or extra output would make the test pass vacuously.
            assert len(x) == len(actual), (
                'expected {} outputs, conversion produced {}'.format(
                    len(actual), len(x)))
            try:
                fetches = [xi.reveal().decode() for xi in x]
            except AttributeError:
                # If reveal()/decode() is unavailable, assume all xi are public
                # and fetch them directly. Only op construction is guarded, so
                # an AttributeError raised by sess.run itself is not masked.
                fetches = list(x)
            output = sess.run(fetches, tag='reveal')
            for o_i, a_i in zip(output, actual):
                np.testing.assert_array_almost_equal(o_i,
                                                     a_i,
                                                     decimal=decimals)
コード例 #3
0
    def _assert_successful_conversion(prot, graph_def, actual, *input_fns,
                                      decimals=3, **kwargs):
        """Convert ``graph_def`` with ``prot`` and compare the revealed output.

        Args:
            prot: protocol instance; its initializers are cleared before
                conversion.
            graph_def: graph definition to convert.
            actual: expected value of the single converted output.
            *input_fns: input functions, passed to ``converter.convert`` as a
                list.
            decimals: precision forwarded to
                ``np.testing.assert_array_almost_equal``. Previously fixed at 3
                and any ``decimals=`` argument was silently ignored.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: if the revealed output differs from ``actual``
                beyond ``decimals``.
        """
        prot.clear_initializers()

        converter = Converter(tfe.get_config(), prot, 'model-provider')
        x = converter.convert(graph_def, register(), 'input-provider', list(input_fns))

        with tfe.Session() as sess:
            sess.run(tf.global_variables_initializer())
            output = sess.run(x.reveal(), tag='reveal')

        np.testing.assert_array_almost_equal(output, actual, decimal=decimals)
コード例 #4
0
    def test_stack_convert(self):
        """Compare a converted Stack graph under Pond against ``run_stack``.

        The graph is exported to ``stack.pb``, read back and converted; it is
        fed the same three arrays that ``run_stack`` receives.
        """
        global global_filename

        first = np.array([1, 4])
        second = np.array([2, 5])
        third = np.array([3, 6])

        tf.reset_default_graph()
        global_filename = "stack.pb"
        path = export_stack(global_filename, first.shape)

        tf.reset_default_graph()
        graph_def = read_graph(path)

        tf.reset_default_graph()
        actual = run_stack(first, second, third)

        tf.reset_default_graph()

        def constant_provider(value):
            # Bind ``value`` eagerly so every provider yields its own array.
            def provide() -> tf.Tensor:
                return tf.constant(value)
            return provide

        with tfe.protocol.Pond() as prot:
            prot.clear_initializers()

            providers = [constant_provider(arr) for arr in (first, second, third)]
            converter = Converter(tfe.get_config(), prot, 'model-provider')
            x = converter.convert(graph_def, register(), 'input-provider',
                                  providers)

            with tfe.Session() as sess:
                sess.run(prot.initializer, tag='init')
                output = sess.run(x.reveal(), tag='reveal')

        np.testing.assert_array_almost_equal(output, actual, decimal=3)
コード例 #5
0
    def test_strided_slice_convert(self):
        """Compare a converted StridedSlice graph under Pond against
        ``run_strided_slice``, feeding both the same input data."""
        tf.reset_default_graph()

        global global_filename
        global_filename = "strided_slice.pb"

        path = export_strided_slice(global_filename)

        tf.reset_default_graph()

        graph_def = read_graph(path)

        tf.reset_default_graph()

        # Single source for the input: the expected result and the converted
        # graph must see identical values. (Also avoids shadowing ``input``.)
        input_data = [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
                      [[5, 5, 5], [6, 6, 6]]]

        actual = run_strided_slice(input_data)

        tf.reset_default_graph()

        with tfe.protocol.Pond() as prot:
            prot.clear_initializers()

            def provide_input():
                return tf.constant(input_data)

            converter = Converter(tfe.get_config(), prot, 'model-provider')

            x = converter.convert(graph_def, register(), 'input-provider',
                                  provide_input)

            with tfe.Session() as sess:
                sess.run(prot.initializer, tag='init')

                output = sess.run(x.reveal(), tag='reveal')

        np.testing.assert_array_almost_equal(output, actual, decimal=3)
コード例 #6
0
    def test_mul_convert(self):
        """Compare a converted Mul graph under Pond against ``run_mul``.

        The converted graph is fed the values 1.0..4.0 reshaped to
        ``input_shape``.
        """
        global global_filename

        input_shape = [4, 1]

        tf.reset_default_graph()
        global_filename = "mul.pb"
        path = export_mul(global_filename, input_shape)

        tf.reset_default_graph()
        graph_def = read_graph(path)

        tf.reset_default_graph()
        actual = run_mul(input_shape)

        tf.reset_default_graph()
        with tfe.protocol.Pond() as prot:
            prot.clear_initializers()

            def provide_input():
                values = np.array([1.0, 2.0, 3.0, 4.0])
                return tf.constant(values.reshape(input_shape))

            converter = Converter(tfe.get_config(), prot, 'model-provider')
            x = converter.convert(graph_def, register(), 'input-provider',
                                  provide_input)

            with tfe.Session() as sess:
                sess.run(prot.initializer, tag='init')
                output = sess.run(x.reveal(), tag='reveal')

        np.testing.assert_array_almost_equal(output, actual, decimal=3)
コード例 #7
0
    def test_cnn_NHWC_convert(self):
        """Compare a converted NHWC CNN graph under Pond against ``run_cnn``.

        The converted graph is fed a tensor of ones of shape ``input_shape``.
        """
        global global_filename

        layout = "NHWC"
        input_shape = [1, 28, 28, 1]

        tf.reset_default_graph()
        global_filename = "cnn_nhwc.pb"
        path = export_cnn(global_filename, input_shape, data_format=layout)

        tf.reset_default_graph()
        graph_def = read_graph(path)

        tf.reset_default_graph()
        actual = run_cnn(input_shape, data_format=layout)

        tf.reset_default_graph()
        with tfe.protocol.Pond() as prot:
            prot.clear_initializers()

            def provide_input():
                ones = np.ones(input_shape)
                return tf.constant(ones)

            converter = Converter(tfe.get_config(), prot, 'model-provider')
            x = converter.convert(graph_def, register(), 'input-provider',
                                  provide_input)

            with tfe.Session() as sess:
                sess.run(prot.initializer, tag='init')
                output = sess.run(x.reveal(), tag='reveal')

        np.testing.assert_array_almost_equal(output, actual, decimal=3)
コード例 #8
0
    def test_argmax_convert(self):
        """Compare a converted ArgMax graph under SecureNN against
        ``run_argmax``, feeding both the same input values."""
        tf.reset_default_graph()

        global global_filename
        global_filename = "argmax.pb"

        input_shape = [5]
        input_data = [1, 2, 3, 4, 5]

        # The trailing 0 is forwarded unchanged to export_argmax/run_argmax
        # (presumably the reduction axis; confirm against the helpers).
        path = export_argmax(global_filename, input_shape, 0)

        tf.reset_default_graph()

        graph_def = read_graph(path)

        tf.reset_default_graph()

        actual = run_argmax(input_data, 0)

        tf.reset_default_graph()

        with tfe.protocol.SecureNN() as prot:
            prot.clear_initializers()

            def provide_input():
                # Feed the same values used to compute ``actual``. An all-ones
                # input makes every index a tie and does not exercise argmax.
                # float64 keeps the fed tensor floating point.
                return tf.constant(
                    np.array(input_data, dtype=np.float64).reshape(input_shape))

            converter = Converter(tfe.get_config(), prot, 'model-provider')

            x = converter.convert(graph_def, register(), 'input-provider', provide_input)

            with tfe.Session() as sess:
                sess.run(tf.global_variables_initializer())

                output = sess.run(x.reveal(), tag='reveal')

        np.testing.assert_array_almost_equal(output, actual, decimal=3)
コード例 #9
0
 def test_empty_model(self):
     """Converting the empty model under SecureNN must raise ValueError."""
     graph_def, prot_class = self._construct_empty_conversion_test(
         'empty_model', protocol='SecureNN')
     ones_input = np.ones([1, 8, 8, 1])
     with prot_class() as prot:
         input_fn = self.ndarray_input_fn(ones_input)
         prot.clear_initializers()
         converter = Converter(
             registry(),
             config=tfe.get_config(),
             protocol=prot,
             model_provider='model-provider',
         )
         with self.assertRaises(ValueError):
             converter.convert(graph_def, 'input-provider', input_fn)
コード例 #10
0
 def test_empty_model(self):
     """Converting the empty model under SecureNN must raise ValueError."""
     graph_def, prot_class = self._construct_empty_conversion_test(
         "empty_model", protocol="SecureNN"
     )
     ones_input = np.ones([1, 8, 8, 1])
     with prot_class() as prot:
         input_fn = self.ndarray_input_fn(ones_input)
         converter = Converter(
             registry(),
             config=tfe.get_config(),
             protocol=prot,
             model_provider="model-provider",
         )
         with self.assertRaises(ValueError):
             converter.convert(graph_def, "input-provider", input_fn)