Beispiel #1
0
    def test_iterate_private(self):
        """Write a private ABY3 input to disk, read it back and iterate over it.

        A 4x2 tensor holding 0..7 is supplied as a private input and written
        to a scratch path. It is then read back with ``tfe.read`` and batched
        twice with ``tfe.iterate`` (once unshuffled, once shuffled). The
        revealed batches are only printed; nothing is asserted yet.
        """
        tf.reset_default_graph()

        prot = ABY3()
        tfe.set_protocol(prot)

        def provide_input():
            return tf.reshape(tf.range(0, 8), [4, 2])

        # define inputs
        x = tfe.define_private_input("input-provider", provide_input)

        # Work inside a private scratch directory instead of a bare mkstemp():
        # no OS-level file descriptor is left open, and whatever ends up under
        # the directory is deleted even if a step below raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_filename = os.path.join(tmp_dir, "x.tfrecord")
            write_op = x.write(tmp_filename)

            with tfe.Session() as sess:
                # initialize variables
                sess.run(tfe.global_variables_initializer())
                # write the private input to disk
                sess.run(write_op)

            x = tfe.read(tmp_filename, batch_size=5, n_columns=2)
            y = tfe.iterate(x, batch_size=3, repeat=True, shuffle=False)
            z = tfe.iterate(x, batch_size=3, repeat=True, shuffle=True)
            with tfe.Session() as sess:
                sess.run(tfe.global_variables_initializer())
                # TODO: fix this test -- the revealed values are only printed,
                # so it cannot currently fail on wrong output.
                print(sess.run(x.reveal()))
                print(sess.run(y.reveal()))
                print(sess.run(y.reveal()))
                print(sess.run(x.reveal()))
                print(sess.run(z.reveal()))
Beispiel #2
0
    def test_iterate_private(self):
        """Write a private ABY3 input to disk, read it back and iterate over it.

        A 4x2 tensor holding 0..7 is supplied as a private input and written
        to a scratch path. It is then read back with ``tfe.read`` and batched
        twice with ``tfe.iterate`` (once unshuffled, once shuffled). The
        revealed batches are only printed; nothing is asserted.
        """
        # Imported locally so this test does not depend on the module's
        # top-level import list.
        import os
        import tempfile

        tf.reset_default_graph()

        prot = ABY3()
        tfe.set_protocol(prot)

        def provide_input():
            return tf.reshape(tf.range(0, 8), [4, 2])

        # define inputs
        x = tfe.define_private_input('input-provider', provide_input)

        # Use a private scratch directory rather than a fixed "x.tfrecord" in
        # the current working directory: the test no longer leaves files
        # behind and cannot collide with other tests using the same name.
        with tempfile.TemporaryDirectory() as tmp_dir:
            record_path = os.path.join(tmp_dir, "x.tfrecord")
            write_op = x.write(record_path)

            with tfe.Session() as sess:
                # initialize variables
                sess.run(tfe.global_variables_initializer())
                # write the private input to disk
                sess.run(write_op)

            x = tfe.read(record_path, batch_size=5, n_columns=2)
            y = tfe.iterate(x, batch_size=3, repeat=True, shuffle=False)
            z = tfe.iterate(x, batch_size=3, repeat=True, shuffle=True)
            with tfe.Session() as sess:
                sess.run(tfe.global_variables_initializer())
                print(sess.run(x.reveal()))
                print(sess.run(y.reveal()))
                print(sess.run(y.reveal()))
                print(sess.run(x.reveal()))
                print(sess.run(z.reveal()))
Beispiel #3
0
    def test_read_private(self):
        """Round-trip a private ABY3 tensor through disk and check the result.

        A 4x2 tensor holding 0..7 is supplied as a private input and written
        to a scratch path. Reading it back with ``batch_size=5`` is expected
        to yield the four written rows followed by the first row again, which
        is what the assertion below checks.
        """
        tf.reset_default_graph()

        prot = ABY3()
        tfe.set_protocol(prot)

        def provide_input():
            return tf.reshape(tf.range(0, 8), [4, 2])

        # define inputs
        x = tfe.define_private_input("input-provider", provide_input)

        # Work inside a private scratch directory instead of a bare mkstemp():
        # no OS-level file descriptor is left open, and whatever ends up under
        # the directory is deleted even if the assertion below fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_filename = os.path.join(tmp_dir, "x.tfrecord")
            write_op = x.write(tmp_filename)

            with tfe.Session() as sess:
                # initialize variables
                sess.run(tfe.global_variables_initializer())
                # write the private input to disk
                sess.run(write_op)

            x = tfe.read(tmp_filename, batch_size=5, n_columns=2)
            with tfe.Session() as sess:
                result = sess.run(x.reveal())
                # Rows 0..3 as written, then row 0 again to fill the batch.
                expected = np.array(list(range(0, 8)) + [0, 1]).reshape([5, 2])
                np.testing.assert_allclose(
                    result,
                    expected,
                    rtol=0.0,
                    atol=0.01,
                )
Beispiel #4
0
    def test_read_private(self):
        """Round-trip a private ABY3 tensor through disk and check the result.

        A 4x2 tensor holding 0..7 is supplied as a private input and written
        to a scratch path. Reading it back with ``batch_size=5`` is expected
        to yield the four written rows followed by the first row again, which
        is what the assertion below checks.
        """
        # Imported locally so this test does not depend on the module's
        # top-level import list.
        import os
        import tempfile

        tf.reset_default_graph()

        prot = ABY3()
        tfe.set_protocol(prot)

        def provide_input():
            return tf.reshape(tf.range(0, 8), [4, 2])

        # define inputs
        x = tfe.define_private_input('input-provider', provide_input)

        # Use a private scratch directory rather than a fixed "x.tfrecord" in
        # the current working directory: the test no longer leaves files
        # behind and cannot collide with other tests using the same name.
        with tempfile.TemporaryDirectory() as tmp_dir:
            record_path = os.path.join(tmp_dir, "x.tfrecord")
            write_op = x.write(record_path)

            with tfe.Session() as sess:
                # initialize variables
                sess.run(tfe.global_variables_initializer())
                # write the private input to disk
                sess.run(write_op)

            x = tfe.read(record_path, batch_size=5, n_columns=2)
            with tfe.Session() as sess:
                result = sess.run(x.reveal())
                # Rows 0..3 as written, then row 0 again to fill the batch.
                expected = np.array(list(range(0, 8)) + [0, 1]).reshape([5, 2])
                np.testing.assert_allclose(result,
                                           expected,
                                           rtol=0.0,
                                           atol=0.01)
                print("test_read_private succeeds")