示例#1
0
# Runtime setup: Rosetta backend log level, numpy print format,
# TensorFlow C++ log filtering, and a fixed numpy RNG seed.
rtt.set_backend_loglevel(1)
np.set_printoptions(suppress=True)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
np.random.seed(0)

# Select the SecureNN protocol and ask the protocol handler which party
# this process is.
rtt.activate("SecureNN")
mpc_player_id = rtt.py_protocol_handler.get_party_id()
BATCH_SIZE = 100
ROW_NUM = 5500

# Real data, read from this party's own directory ../dsets/P<party id>/.
# Unlike plain TensorFlow, every dataset names the party that owns its
# plaintext: the feature file is opened twice (owners 0 and 1) and the
# label file is owned by party 1.
file_x = "../dsets/P" + str(mpc_player_id) + "/mnist_train_x.csv"
file_y = "../dsets/P" + str(mpc_player_id) + "/mnist_train_y.csv"
X_train_0, X_train_1 = (
    rtt.PrivateTextLineDataset(file_x, data_owner=owner) for owner in (0, 1)
)
Y_train = rtt.PrivateTextLineDataset(file_y, data_owner=1)

# Per-party scratch directory, always recreated empty. Wiping stale
# contents works around a TF 1.14 cache-file bug.
cache_dir = "./temp{}".format(mpc_player_id)
if os.path.exists(cache_dir):
    import shutil
    shutil.rmtree(cache_dir)
os.makedirs(cache_dir, exist_ok=True)


# dataset decode
示例#2
0
# Per-party input files: each party reads ../dsets/P<party id>/<file name>.
# NOTE(review): mpc_player_id and DIM_NUM are defined outside this view.
filex_name = "cls_train_x.csv"
filey_name = "cls_train_y.csv"

file_x = "../dsets/P" + str(mpc_player_id) + "/" + filex_name
file_y = "../dsets/P" + str(mpc_player_id) + "/" + filey_name

print("file_x:", file_x)
print("file_y:", file_y)
print("DIM_NUM:", DIM_NUM)


# Training datasets. Each one names the party that owns its plaintext.
dataset_x0 = rtt.PrivateTextLineDataset(
    file_x, data_owner=0)  # P0 holds the file_x data
dataset_x1 = rtt.PrivateTextLineDataset(
    file_x, data_owner=1)  # P1 holds the file_x data
dataset_y = rtt.PrivateTextLineDataset(
    file_y, data_owner=0)  # P0 holds the file_y data


# dataset decode
def decode_p0(line):
    """Split one comma-separated text line and mark it as party 0's input.

    Args:
        line: one line from a text-line dataset (presumably a scalar string
            tensor; confirm against the dataset producing it).

    Returns:
        The comma-split string fields passed through ``rtt.PrivateInput``
        with ``data_owner=0``.
    """
    raw_values = tf.string_split([line], ',').values
    return rtt.PrivateInput(raw_values, data_owner=0)


def decode_p1(line):
    """Split one comma-separated text line and mark it as party 1's input.

    Counterpart of ``decode_p0`` for datasets whose plaintext is owned by
    party 1.

    Args:
        line: one line from a text-line dataset (presumably a scalar string
            tensor; confirm against the dataset producing it).

    Returns:
        The comma-split string fields passed through ``rtt.PrivateInput``
        with ``data_owner=1``.
    """
    fields = tf.string_split([line], ',').values
    # Wrap and return the fields. Without this the function returned None,
    # which breaks any ``dataset.map(decode_p1)`` pipeline.
    fields = rtt.PrivateInput(fields, data_owner=1)
    return fields
def decode_p0(line):
    """Return the comma-split fields of ``line`` as party 0's private input.

    The fields produced by ``tf.string_split`` are handed to
    ``rtt.PrivateInput`` with ``data_owner=0``.
    """
    return rtt.PrivateInput(tf.string_split([line], ',').values, data_owner=0)


def decode_p1(line):
    """Return the comma-split fields of ``line`` as party 1's private input.

    The fields produced by ``tf.string_split`` are handed to
    ``rtt.PrivateInput`` with ``data_owner=1``.
    """
    return rtt.PrivateInput(tf.string_split([line], ',').values, data_owner=1)


# Raw text-line datasets: party 0 owns file_x, party 1 owns file_y.
dataset_x = rtt.PrivateTextLineDataset(file_x, data_owner=0)  # owner is p0
dataset_y = rtt.PrivateTextLineDataset(file_y, data_owner=1)  # owner is p1

# Decode every line with the decoder matching its owner, then batch.
# NOTE(review): batch_size is defined outside this view.
dataset_x = (
    dataset_x
    .map(decode_p0)
    .batch(batch_size)
)

dataset_y = (
    dataset_y
    .map(decode_p1)
    .batch(batch_size)
)

# Initializable iterators; presumably their initializers are run in a
# session elsewhere before the iterators are consumed.
iter_x = dataset_x.make_initializable_iterator()
iter_y = dataset_y.make_initializable_iterator()

# Variable initialised from a 4x1 array of ones via rtt.private_input
# (first argument 0 — presumably the owning party; confirm).
v = tf.Variable(rtt.private_input(0, np.ones([4, 1])))