Exemplo n.º 1
0
import latticex.rosetta as rtt
import tensorflow as tf
import sys, os
import numpy as np
# Test-harness setup: make printing and seeding deterministic, then select
# and activate the Rosetta MPC protocol used by the rest of the script.
np.set_printoptions(suppress=True)  # print floats in fixed-point, not scientific notation
tf.set_random_seed(0)  # fixed graph-level seed for reproducible runs

# SecureNN is the default; the ROSETTA_TEST_PROTOCOL environment variable
# overrides it.
protocol = "SecureNN"
if "ROSETTA_TEST_PROTOCOL" in os.environ:
    protocol = os.environ["ROSETTA_TEST_PROTOCOL"]
    print("***** test_cases use ", protocol)
else:
    # Name the protocol that is actually activated, not a hard-coded label.
    print("***** test_cases use default", protocol, "protocol ")
rtt.activate(protocol)

print("rtt.get_protocol_name():", rtt.get_protocol_name())
# NOTE(review): 'patyid' looks like a typo for 'partyid'; kept unchanged
# because code outside this view may reference it.
patyid = rtt.get_party_id("")

###############################
# 1-d
###############################
# Section: tf.reduce_mean on a 1-D tensor, with and without an explicit
# axis and with and without keepdims.
print('=================1-d=========================')
x = [1, 8, 3]
print(x)

# float64 variable initialised from the Python list above.
a = tf.Variable(x, dtype=tf.float64)
# No axis given: mean over all elements, yielding a scalar.
aa = tf.reduce_mean(a)
# Mean along axis 0; for this 1-D input that is also a full reduction.
a0 = tf.reduce_mean(a, axis=0)

# keepdims=True variants: the reduced dimension is retained with length 1
# instead of being dropped.
aa_k = tf.reduce_mean(a, keepdims=True)
a0_k = tf.reduce_mean(a, axis=0, keepdims=True)
Exemplo n.º 2
0
# NOTE(review): backend log level 2 — confirm which verbosity this maps to.
rtt.set_backend_loglevel(2)

# Select the MPC protocol: SecureNN by default, overridable through the
# ROSETTA_TEST_PROTOCOL environment variable.
protocol = "SecureNN"

if "ROSETTA_TEST_PROTOCOL" in os.environ:
    protocol = os.environ["ROSETTA_TEST_PROTOCOL"]
    print("***** test_cases use ", protocol)
else:
    # Name the protocol that is actually activated, not a hard-coded label.
    print("***** test_cases use default", protocol, "protocol ")

rtt.activate(protocol)

# NOTE(review): 'patyid' looks like a typo for 'partyid'; kept unchanged
# because code outside this view may reference it.
patyid = rtt.get_party_id()
# Report the active protocol together with this process's party id.
print(
    "rtt.get_protocol_name: ",
    rtt.get_protocol_name(),
    "party: ",
    patyid,
)

# float * tf.Variable()
# Operands for the float * tf.Variable() cases: 10x10 float arrays of
# ones, twos, and negative ones.
num_a = np.ones([10, 10])
num_b = np.ones([10, 10]) * 2
num_neg_a = -num_a

# only ALICE could own private data
# Every private input below is fed with party index 0 (presumably ALICE,
# per the note above) through rtt.private_input, then wrapped in a
# tf.Variable.
X = tf.Variable(rtt.private_input(0, num_a))
Y = tf.Variable(rtt.private_input(0, num_b))
P = tf.Variable(rtt.private_input(0, num_neg_a))
# Plain tf.constant copies of num_a / num_b (not passed through
# rtt.private_input).
CX = tf.constant(num_a)
CY = tf.constant(num_b)