-
Notifications
You must be signed in to change notification settings - Fork 0
/
NKEP.py
93 lines (68 loc) · 2.65 KB
/
NKEP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import random
from TPM import *
#input_size = 100
#hidden_node_num = 10
#weight_range = 100000
def input_generator(input_size, weight_range):
    """Return a list of ``input_size`` random integers in [-weight_range, weight_range].

    Each value is drawn with ``random.randint``, so both bounds are inclusive.
    The same vector is fed to both TPMs in one exchange round.

    NOTE(review): classic tree-parity-machine formulations use inputs in
    {-1, +1}; confirm the TPM module really expects inputs drawn from the
    weight range.

    Args:
        input_size: number of values to generate (0 yields an empty list).
        weight_range: non-negative bound; values lie in [-weight_range, weight_range].

    Returns:
        A new list of ints.
    """
    # range (not xrange) works on both Python 2 and Python 3; the draws happen
    # in the same order as the original append loop.
    return [random.randint(-weight_range, weight_range) for _ in range(input_size)]
def key_exchange(TPM1, TPM2):
    """Perform one exchange round between two TPMs.

    A single random input vector, sized and ranged from ``TPM1``'s
    ``input_num`` and ``weight_range``, is fed to both machines. Only when
    their outputs agree does each machine call its ``hebbian_learning_rule``
    with the shared input, its own output and the other machine's output.

    Returns:
        1 if the outputs matched and both machines were updated, otherwise 0.
    """
    shared_inputs = input_generator(TPM1.input_num, TPM1.weight_range)
    first_out = TPM1.output(shared_inputs)
    second_out = TPM2.output(shared_inputs)

    outputs_agree = first_out == second_out
    if not outputs_agree:
        # Disagreeing rounds leave both machines untouched.
        return 0

    TPM1.hebbian_learning_rule(shared_inputs, first_out, second_out)
    TPM2.hebbian_learning_rule(shared_inputs, second_out, first_out)
    return 1
def key_exchange_one_only(TPM1, TPM2):
    """Perform one exchange round in which only ``TPM1`` learns.

    Both machines receive the same random input vector (sized and ranged
    from ``TPM1``). If their outputs agree, only
    ``TPM1.hebbian_learning_rule`` is called; ``TPM2`` is never updated.

    NOTE(review): presumably meant to model a one-sided learner (e.g. an
    eavesdropper following a public exchange) — confirm intent with callers.

    Returns:
        1 if the outputs matched and ``TPM1`` was updated, otherwise 0.
    """
    probe = input_generator(TPM1.input_num, TPM1.weight_range)
    learner_out, partner_out = TPM1.output(probe), TPM2.output(probe)
    if learner_out == partner_out:
        TPM1.hebbian_learning_rule(probe, learner_out, partner_out)
        return 1
    return 0
def synchronize(TPM1, TPM2, cutoff):
    """Run ``key_exchange`` rounds until the net agreement tally reaches ``cutoff``.

    Each round whose outputs agree adds one to a running tally, and each
    disagreeing round subtracts one. The tally can therefore go negative. The
    loop stops as soon as the tally reaches ``cutoff``. A non-positive
    ``cutoff`` returns without running any round.

    Choosing the cutoff matters. If it is too small, the machines may not be
    synchronized yet. If it is too large, eavesdroppers may be able to
    synchronize as well.

    Returns:
        Always 0.
    """
    tally = 0
    while tally < cutoff:
        tally += 1 if key_exchange(TPM1, TPM2) == 1 else -1
    return 0
def check_weights(TPM1, TPM2):
    """Check whether two TPMs hold identical weights.

    Run this after ``synchronize`` to confirm that the machines really share
    the same weights, i.e. that the exchange produced a common key.

    Returns:
        0 if ``TPM1.weights`` and ``TPM2.weights`` have the same length and
        are equal element by element, otherwise -1.
    """
    first, second = TPM1.weights, TPM2.weights
    # Weight lists of different lengths can never be synchronized. Checking
    # this up front avoids two problems: an IndexError when TPM2's list is
    # shorter, and a false "synced" result when it is longer (its trailing
    # weights would otherwise never be compared).
    if len(first) != len(second):
        return -1
    if any(a != b for a, b in zip(first, second)):
        return -1
    return 0
def nkep():
    """Run an interactive neural key exchange between two fresh TPMs.

    Steps:
    1. Prompt on stdin for the input count, hidden node count, weight range
       and synchronization cutoff.
    2. Build two ``TPM`` instances with identical parameters.
    3. Call ``synchronize`` repeatedly until ``check_weights`` reports
       identical weights.
    4. Print both weight sets via ``TPM.print_weights``.

    Raises:
        ValueError: if an entry is not an integer, if the hidden node count
            is not a positive divisor of the input count, or if the cutoff is
            not positive.
    """
    # raw_input is the non-evaluating prompt on Python 2. Python 3 renamed it
    # to input. Resolving it here keeps this entry point runnable on both.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    input_size = int(read_line("Enter the desired number of inputs for the networks\n"))
    hidden_node_num = int(read_line("Enter the desired number of hidden nodes. Must be able to divide the input size.\n"))
    # The prompt above states this requirement. Enforce it before building the
    # TPMs instead of handing them an inconsistent configuration.
    if hidden_node_num <= 0 or input_size % hidden_node_num != 0:
        raise ValueError(
            "hidden node count (%d) must be a positive divisor of the input count (%d)"
            % (hidden_node_num, input_size))
    weight_range = int(read_line("What weight range would you like to use? Note: Enter only one number, the range will be made up of the positive and negative versions of that number.\n"))
    cutoff = int(read_line("Enter a cutoff for the number of correct outputs needed to be synchronized:\n"))
    # With cutoff <= 0, synchronize() returns without running a single round.
    # The weights would then never change and the loop below would spin forever.
    if cutoff <= 0:
        raise ValueError("cutoff must be a positive integer, got %d" % cutoff)

    first = TPM(input_size, hidden_node_num, weight_range)
    second = TPM(input_size, hidden_node_num, weight_range)
    while check_weights(first, second) < 0:
        synchronize(first, second, cutoff)

    # A single parenthesized argument prints identically on Python 2 and 3.
    print("Weights are now synced, printing them now\n")
    print("First TPM's weights")
    first.print_weights()
    print("\n")
    print("Second TPM's weights")
    second.print_weights()