def test_iterate_private(self):
    """Round-trip a private tensor through a file, then iterate over it.

    A 4x2 private input (values 0..7) is written to a temporary file, read
    back with ``tfe.read`` (batch_size=5, n_columns=2), and wrapped by
    ``tfe.iterate`` once without and once with shuffling. The revealed
    values are only printed; nothing is asserted yet.
    """
    tf.reset_default_graph()

    prot = ABY3()
    tfe.set_protocol(prot)

    def provide_input():
        return tf.reshape(tf.range(0, 8), [4, 2])

    # define inputs
    x = tfe.define_private_input("input-provider", provide_input)

    # mkstemp hands back an open OS-level descriptor; only the path is
    # needed by x.write / tfe.read, so close it immediately to avoid a leak.
    fd, tmp_filename = tempfile.mkstemp()
    os.close(fd)
    try:
        write_op = x.write(tmp_filename)

        with tfe.Session() as sess:
            # initialize variables
            sess.run(tfe.global_variables_initializer())
            # write the private tensor to the temp file
            sess.run(write_op)

        x = tfe.read(tmp_filename, batch_size=5, n_columns=2)
        y = tfe.iterate(x, batch_size=3, repeat=True, shuffle=False)
        z = tfe.iterate(x, batch_size=3, repeat=True, shuffle=True)

        with tfe.Session() as sess:
            sess.run(tfe.global_variables_initializer())
            # TODO: fix this test -- it only prints the revealed values and
            # does not assert anything about them.
            print(sess.run(x.reveal()))
            print(sess.run(y.reveal()))
            print(sess.run(y.reveal()))
            print(sess.run(x.reveal()))
            print(sess.run(z.reveal()))
    finally:
        # Remove the temp file even if any step above raised.
        os.remove(tmp_filename)
def test_iterate_private(self):
    """Round-trip a private tensor through a file, then iterate over it.

    A 4x2 private input (values 0..7) is written to a temporary file, read
    back with ``tfe.read`` (batch_size=5, n_columns=2), and wrapped by
    ``tfe.iterate`` once without and once with shuffling. The revealed
    values are only printed; nothing is asserted.

    The file is created with ``tempfile.mkstemp`` instead of a fixed name in
    the current directory, so runs do not pollute the cwd or collide with
    each other, and it is always removed afterwards.
    """
    import os
    import tempfile

    tf.reset_default_graph()

    prot = ABY3()
    tfe.set_protocol(prot)

    def provide_input():
        return tf.reshape(tf.range(0, 8), [4, 2])

    # define inputs
    x = tfe.define_private_input('input-provider', provide_input)

    # Only the path is needed; close the descriptor mkstemp opens.
    fd, tmp_filename = tempfile.mkstemp()
    os.close(fd)
    try:
        write_op = x.write(tmp_filename)

        with tfe.Session() as sess:
            # initialize variables
            sess.run(tfe.global_variables_initializer())
            # write the private tensor to the temp file
            sess.run(write_op)

        x = tfe.read(tmp_filename, batch_size=5, n_columns=2)
        y = tfe.iterate(x, batch_size=3, repeat=True, shuffle=False)
        z = tfe.iterate(x, batch_size=3, repeat=True, shuffle=True)

        with tfe.Session() as sess:
            sess.run(tfe.global_variables_initializer())
            print(sess.run(x.reveal()))
            print(sess.run(y.reveal()))
            print(sess.run(y.reveal()))
            print(sess.run(x.reveal()))
            print(sess.run(z.reveal()))
    finally:
        # Clean up the temp file even if any step above raised.
        os.remove(tmp_filename)
def test_read_private(self):
    """Round-trip a private tensor through a file and check the read-back.

    A 4x2 private input (values 0..7) is written to a temporary file and read
    back with ``tfe.read`` using batch_size=5. The expected value asserted
    below is the 8 written values followed by ``[0, 1]``, i.e. the fifth row
    repeats the first one.
    """
    tf.reset_default_graph()

    prot = ABY3()
    tfe.set_protocol(prot)

    def provide_input():
        return tf.reshape(tf.range(0, 8), [4, 2])

    # define inputs
    x = tfe.define_private_input("input-provider", provide_input)

    # mkstemp hands back an open OS-level descriptor; only the path is
    # needed by x.write / tfe.read, so close it immediately to avoid a leak.
    fd, tmp_filename = tempfile.mkstemp()
    os.close(fd)
    try:
        write_op = x.write(tmp_filename)

        with tfe.Session() as sess:
            # initialize variables
            sess.run(tfe.global_variables_initializer())
            # write the private tensor to the temp file
            sess.run(write_op)

        x = tfe.read(tmp_filename, batch_size=5, n_columns=2)
        with tfe.Session() as sess:
            result = sess.run(x.reveal())

        np.testing.assert_allclose(
            result,
            np.array(list(range(0, 8)) + [0, 1]).reshape([5, 2]),
            rtol=0.0,
            atol=0.01,
        )
    finally:
        # Remove the temp file even when the assertion (or anything) fails.
        os.remove(tmp_filename)
def test_read_private(self):
    """Round-trip a private tensor through a file and check the read-back.

    A 4x2 private input (values 0..7) is written to a temporary file and read
    back with ``tfe.read`` using batch_size=5. The expected value asserted
    below is the 8 written values followed by ``[0, 1]``, i.e. the fifth row
    repeats the first one.

    The file is created with ``tempfile.mkstemp`` instead of a fixed name in
    the current directory, so runs do not pollute the cwd or collide with
    each other, and it is always removed afterwards.
    """
    import os
    import tempfile

    tf.reset_default_graph()

    prot = ABY3()
    tfe.set_protocol(prot)

    def provide_input():
        return tf.reshape(tf.range(0, 8), [4, 2])

    # define inputs
    x = tfe.define_private_input('input-provider', provide_input)

    # Only the path is needed; close the descriptor mkstemp opens.
    fd, tmp_filename = tempfile.mkstemp()
    os.close(fd)
    try:
        write_op = x.write(tmp_filename)

        with tfe.Session() as sess:
            # initialize variables
            sess.run(tfe.global_variables_initializer())
            # write the private tensor to the temp file
            sess.run(write_op)

        x = tfe.read(tmp_filename, batch_size=5, n_columns=2)
        with tfe.Session() as sess:
            result = sess.run(x.reveal())

        np.testing.assert_allclose(
            result,
            np.array(list(range(0, 8)) + [0, 1]).reshape([5, 2]),
            rtol=0.0,
            atol=0.01,
        )
        print("test_read_private succeeds")
    finally:
        # Clean up the temp file even when the assertion (or anything) fails.
        os.remove(tmp_filename)