def _assert_successful_conversion(
        prot,
        graph_def,
        actual,
        *input_fns,
        decimals=3,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Convert ``graph_def`` with ``prot`` and compare the outputs to ``actual``.

        Args:
            prot: Protocol instance handed to ``Converter`` as ``protocol``.
            graph_def: Graph definition passed to ``Converter.convert``.
            actual: Expected output value, or a list/tuple of expected values
                when the conversion yields a list/tuple of outputs.
            *input_fns: Input functions, forwarded to ``Converter.convert`` as
                a list.
            decimals: Precision passed to
                ``np.testing.assert_array_almost_equal``.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If ``actual`` is not a sequence when several outputs
                are produced, if the number of outputs differs from the number
                of expected values, or if any value does not match.
        """
        converter = Converter(
            registry(),
            config=tfe.get_config(),
            protocol=prot,
            model_provider="model-provider",
        )
        x = converter.convert(graph_def, "input-provider", list(input_fns))

        with tfe.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if not isinstance(x, (list, tuple)):
                x = [x]
                actual = [actual]
            else:
                assert isinstance(
                    actual, (list, tuple)
                ), "expected output to be tensor sequence"
            # zip() below would silently skip unmatched entries; require the
            # number of produced outputs to equal the number of expectations.
            assert len(x) == len(actual), (
                f"expected {len(actual)} outputs, got {len(x)}"
            )
            # Only the attribute access belongs in the try: an AttributeError
            # raised while running the session must propagate instead of
            # silently triggering the public-tensor fallback.
            try:
                fetches = [xi.reveal().decode() for xi in x]
            except AttributeError:
                # Assume every xi is public and can be fetched directly.
                fetches = x
            output = sess.run(fetches, tag="reveal")
            for o_i, a_i in zip(output, actual):
                np.testing.assert_array_almost_equal(o_i, a_i, decimal=decimals)
Example #2
0
    def _assert_successful_conversion(prot,
                                      graph_def,
                                      actual,
                                      *input_fns,
                                      decimals=3,
                                      **kwargs):  # pylint: disable=unused-argument
        """Convert ``graph_def`` with ``prot`` and compare the outputs to ``actual``.

        Args:
            prot: Protocol instance; its initializers are cleared before the
                ``Converter`` is built.
            graph_def: Graph definition passed to ``Converter.convert``.
            actual: Expected output value, or a list/tuple of expected values
                when the conversion yields a list/tuple of outputs.
            *input_fns: Input functions, forwarded to ``Converter.convert`` as
                a list.
            decimals: Precision passed to
                ``np.testing.assert_array_almost_equal``.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If ``actual`` is not a sequence when several outputs
                are produced, if the number of outputs differs from the number
                of expected values, or if any value does not match.
        """
        prot.clear_initializers()
        converter = Converter(tfe.get_config(), prot, 'model-provider')
        x = converter.convert(graph_def, registry(), 'input-provider',
                              list(input_fns))

        with tfe.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if not isinstance(x, (list, tuple)):
                x = [x]
                actual = [actual]
            else:
                assert isinstance(
                    actual,
                    (list, tuple)), "expected output to be tensor sequence"
            # zip() below would silently skip unmatched entries; require the
            # number of produced outputs to equal the number of expectations.
            assert len(x) == len(actual), (
                "expected {} outputs, got {}".format(len(actual), len(x)))
            # Only the attribute access belongs in the try: an AttributeError
            # raised while running the session must propagate instead of
            # silently triggering the public-tensor fallback.
            try:
                fetches = [xi.reveal().decode() for xi in x]
            except AttributeError:
                # Assume every xi is public and can be fetched directly.
                fetches = list(x)
            output = sess.run(fetches, tag='reveal')
            for o_i, a_i in zip(output, actual):
                np.testing.assert_array_almost_equal(o_i,
                                                     a_i,
                                                     decimal=decimals)
Example #3
0
 def test_empty_model(self):
     """Check that converting the 'empty_model' test graph raises ValueError."""
     ones = np.ones([1, 8, 8, 1])
     graph_def, prot_class = self._construct_empty_conversion_test(
         'empty_model', protocol='SecureNN')
     with prot_class() as prot:
         feed_fn = self.ndarray_input_fn(ones)
         # Start from a clean initializer list on this protocol instance.
         prot.clear_initializers()
         converter = Converter(
             registry(),
             config=tfe.get_config(),
             protocol=prot,
             model_provider='model-provider',
         )
         with self.assertRaises(ValueError):
             converter.convert(graph_def, 'input-provider', feed_fn)
 def test_empty_model(self):
     """Check that converting the "empty_model" test graph raises ValueError."""
     sample = np.ones([1, 8, 8, 1])
     graph_def, prot_class = self._construct_empty_conversion_test(
         "empty_model", protocol="SecureNN"
     )
     with prot_class() as prot:
         feed_fn = self.ndarray_input_fn(sample)
         converter = Converter(
             registry(),
             config=tfe.get_config(),
             protocol=prot,
             model_provider="model-provider",
         )
         with self.assertRaises(ValueError):
             converter.convert(graph_def, "input-provider", feed_fn)
Example #5
0
    'server0', 'server1', 'crypto-producer', 'prediction-client',
    'weights-provider'
])


def provide_input() -> tf.Tensor:
    """Return a float32 constant of shape (1, 1, 28, 28).

    The values are freshly drawn standard-normal samples on every call.
    """
    samples = np.random.normal(size=(1, 1, 28, 28))
    return tf.constant(samples, dtype=tf.float32)


def receive_output(tensor: tf.Tensor) -> tf.Tensor:
    """Print ``tensor`` when it is evaluated and return its value unchanged.

    In graph (session) mode ``tf.print`` only builds an op. That op does
    nothing unless something depends on it. Wrapping the returned tensor in
    ``tf.identity`` under a control dependency on the print op ensures that
    fetching the result also prints it.
    """
    print_op = tf.print(tensor)
    with tf.control_dependencies([print_op]):
        return tf.identity(tensor)


try:
    with tfe.protocol.Pond(
            *config.get_players('server0, server1, crypto-producer')) as prot:

        # Convert graph_def under the Pond protocol. The weights-provider
        # player is handed to the Converter. Inputs come from provide_input
        # on the prediction-client player.
        c = convert.Converter(config, prot, config.get_player('weights-provider'))
        x = c.convert(graph_def, registry(),
                      config.get_player('prediction-client'), provide_input)

        prediction_op = prot.define_output(config.get_player('prediction-client'),
                                           x, receive_output)

        with tfe.Session(config=config) as sess:
            sess.run(tfe.global_variables_initializer(), tag='init')

            sess.run(prediction_op, tag='prediction')
finally:
    # Remove the model file even if conversion or execution raised.
    os.remove(model_filename)