-
Notifications
You must be signed in to change notification settings - Fork 0
/
project3.py
98 lines (83 loc) · 3.25 KB
/
project3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from collections import OrderedDict
from hopfieldnet import HopfieldNet
from sampleset import SampleSet
def show_menu():
    """Run the interactive text menu until the user chooses to quit.

    Options:
        0 -- leave the menu loop.
        1 -- train a network from a samples file and save its weights to a file.
        2 -- load weights, test them against a samples file, and save the results.

    File problems (``OSError`` such as a missing file) and malformed numeric data
    (``ValueError``) are reported, and the menu is shown again instead of aborting
    the program. End of input on stdin is treated like choosing "0".
    """
    while True:
        print('_' * 30)
        print('0) Enter 0 to quit.')
        print('1) Enter 1 to train a net.')
        print('2) Enter 2 to test a network.')
        try:
            # strip() so that stray spaces around the choice still match.
            selection = input('Select: ').strip()
        except EOFError:
            # stdin closed (e.g. piped input ran out): quit instead of crashing.
            break
        if selection == '0':
            break
        try:
            if selection == '1':
                print('Enter the filename that contains the samples to store.')
                samples_filename = input('>>> ')
                print('Enter a filename to store weights.')
                weights_filename = input('>>> ')
                train_network(samples_filename, weights_filename)
                print('Training Network... OK')
            elif selection == '2':
                print('Enter the filename that contains the weights settings.')
                weights_filename = input('>>> ')
                print('Enter a filename that contains the testing samples.')
                testing_filename = input('>>> ')
                print('Enter a filename to store results.')
                results_filename = input('>>> ')
                results = test_network(weights_filename, testing_filename)
                print('Testing Network... OK')
                save_results_to_file(results, results_filename)
                print('Saving results to "{}"... OK'.format(results_filename))
            else:
                print('Unknown selection "{}".'.format(selection))
        except (OSError, ValueError) as err:
            # A typo in a filename or a corrupt weights file should not kill
            # the whole interactive session; report it and show the menu again.
            print('Error: {}'.format(err))
def train_network(samples_filename, weights_filename):
    """Train a Hopfield net on a sample file and persist the learned weights.

    The network is sized from the sample set's ``sample_size``, initialized,
    fed every sample through ``train`` in iteration order, and its ``weights``
    are then written to *weights_filename* with save_weights_to_file.
    """
    samples = SampleSet()
    samples.init_from_file(samples_filename)

    network = HopfieldNet(samples.sample_size)
    network.initialize()
    for pattern in samples:
        network.train(pattern)

    save_weights_to_file(network.weights, weights_filename)
def test_network(weights_filename, samples_filename):
    """Evaluate stored weights against every sample in a file.

    The weight matrix is loaded with read_weights_from_file, and a HopfieldNet
    sized by the number of weight rows is built from it. Each sample from
    *samples_filename* is then passed to ``net.test``.

    Returns:
        list: the ``net.test`` output for each sample, in sample order.
    """
    weights = read_weights_from_file(weights_filename)
    network = HopfieldNet(len(weights), weights)

    samples = SampleSet()
    samples.init_from_file(samples_filename)

    return [network.test(pattern) for pattern in samples]
def save_weights_to_file(weights, filename):
    """Write a weight matrix to *filename* as plain text.

    One matrix row per line, values within a row separated by single spaces,
    each value rendered with ``str``. No newline follows the last row.
    """
    rows = (' '.join(map(str, row)) for row in weights)
    with open(filename, 'w') as f:
        f.write('\n'.join(rows))
def read_weights_from_file(filename):
    """Read a weight matrix from a text file.

    Each non-blank line becomes one row; the row's entries are the
    whitespace-separated values on that line, converted with ``int``.

    Args:
        filename: path of the weights file.

    Returns:
        list[list[int]]: the weight rows in file order.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if an entry is not an integer literal.
    """
    # NOTE(review): entries are parsed with int(); if the network ever stores
    # non-integer weights, the saved text (via str()) would not parse here —
    # confirm weights are always integral.
    weights = []
    with open(filename) as f:
        for line in f:
            fields = line.split()
            # Skip blank lines (e.g. a trailing newline added by an editor);
            # otherwise they become empty rows and inflate len(weights), which
            # callers may use as the network size.
            if fields:
                weights.append([int(value) for value in fields])
    return weights
def save_results_to_file(results, filename):
    """Write every test result to *filename* as a text pattern.

    Each result is converted with ``generate_pattern(result, 10)`` and its rows
    are joined with newlines; consecutive results are separated by one blank
    line. All results are formatted before the file is opened.
    """
    blocks = ['\n'.join(generate_pattern(result, 10)) for result in results]
    with open(filename, 'w') as out:
        out.write('\n\n'.join(blocks))
def generate_pattern(vector, row_size):
    """Render *vector* as rows of text, *row_size* symbols per row.

    Each element equal to 1 becomes 'o'; every other value becomes '_'. The
    final row is shorter than *row_size* when len(vector) is not a multiple
    of it.

    Args:
        vector: iterable of bit values (each compared with ``== 1``).
        row_size: number of symbols per row; must be positive.

    Returns:
        list[str]: the rows in order (an empty list for an empty vector).

    Raises:
        ValueError: if row_size is not positive.
    """
    if row_size <= 0:
        raise ValueError('row_size must be positive, got {}'.format(row_size))
    symbols = ['o' if bit == 1 else '_' for bit in vector]
    # Fixed-size slices; the last slice keeps the leftover row, which the old
    # counter-based loop dropped (and it also ignored row_size, always using 10).
    return [''.join(symbols[start:start + row_size])
            for start in range(0, len(symbols), row_size)]
if __name__ == "__main__":
    # Launch the interactive menu only when executed as a script, not on import.
    show_menu()