-
Notifications
You must be signed in to change notification settings - Fork 0
/
play.py
executable file
·136 lines (100 loc) · 3 KB
/
play.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python3
from train import Net, update_model, predict
import sys
import numpy as np
import torch
import random
import connect_four
from datetime import datetime
from connect_four import RED, YELLOW
def request_move(model, board, chance=2):
    """Return move scores for ``board``: random noise or the model's prediction.

    A single ``random.random()`` roll decides the source. If the roll is below
    ``chance`` the scores are ``torch.rand(7)`` (uniform in [0, 1)); otherwise
    ``predict(model, board)`` is returned unchanged.

    NOTE(review): ``random.random()`` is always < 1, so any ``chance`` >= 1 --
    including the default of 2 -- makes every move random and the model is
    never consulted. Confirm whether that default is intentional.
    """
    # The roll is drawn first; torch.rand(7) is only drawn on the random branch.
    return torch.rand(7) if random.random() < chance else predict(model, board)
def save_model(name, model, directory='./nets'):
    """Save ``model.state_dict()`` as ``<directory>/model-<name>-<HH-MM-SS>.pth``.

    The target directory is created first if it is missing; without this,
    ``torch.save`` fails when ``./nets`` does not exist yet.

    Args:
        name: label embedded in the checkpoint file name (e.g. ``'one'``).
        model: object whose ``state_dict()`` is saved.
        directory: folder to write into; defaults to ``./nets`` as before.

    Returns:
        The path the checkpoint was written to.
    """
    import os

    os.makedirs(directory, exist_ok=True)
    # Only hours-minutes-seconds go into the name, so two saves with the same
    # ``name`` in the same second (or at the same clock time on another day)
    # overwrite each other.
    current = datetime.now().strftime("%H-%M-%S")
    path = os.path.join(directory, 'model-%s-%s.pth' % (name, current))
    torch.save(model.state_dict(), path)
    return path
def view_parameters(layer):
    """Debug helper: print all of ``layer.parameters()`` as one Python list."""
    collected = list(layer.parameters())
    print(collected)
def _place_piece(game, column, colour):
    """Drop a ``colour`` piece, starting at ``column`` and moving right on failure.

    Columns are tried in the order column, column + 1, ... wrapping modulo 7,
    so each of the 7 columns is attempted at most once.

    Returns:
        True as soon as ``game.insert`` succeeds, False if all 7 attempts raised.
    """
    for offset in range(7):
        try:
            game.insert((column + offset) % 7, colour)
        except Exception:
            # Any exception from insert is treated as "this column is unusable".
            # NOTE(review): this also hides genuine bugs inside insert; narrow it
            # once connect_four's "column full" exception type is confirmed.
            continue
        return True
    return False


def play(net_one, net_two):
    """Play one Connect Four game between two networks, then train both on it.

    ``net_one`` ("foo") moves first with RED pieces and ``net_two`` ("bar")
    replies with YELLOW. On every turn the board seen by the mover, its one-hot
    move (a (1, 7) numpy array) and its score tensor are recorded.

    When ``game.won()`` becomes true the board is printed and the player who
    moved last is treated as the winner. ``update_model`` is called first for
    the winner with its own moves, then for the loser with the complement of
    its moves (1 - move). If a piece cannot be placed in any column, the game
    is abandoned: nothing is printed, neither network is updated, and None is
    returned.
    """
    nets = {True: net_one, False: net_two}
    # Per-player history keyed by "is it foo's turn?": True -> foo, False -> bar.
    history = {side: {'boards': [], 'moves': [], 'preds': []} for side in (True, False)}
    foos_turn = True
    game = connect_four.Game()

    while not game.won():
        board = game.to_tensor(foos_turn)
        prediction = request_move(nets[foos_turn], board)

        # One-hot encode every column tying for the top score and zero those
        # scores in place.
        # NOTE(review): the zeroed tensor is what gets stored and later passed
        # to update_model -- confirm that is intended, and that the in-place
        # write is safe if the tensor belongs to an autograd graph.
        top_score = torch.max(prediction)
        move = np.zeros((1, 7))
        for col, score in enumerate(prediction):
            if score == top_score:
                move[:, col] = 1
                prediction[col] = 0.
        # The first of any tied columns is the one attempted.
        column = move.tolist()[0].index(1.)

        if not _place_piece(game, column, RED if foos_turn else YELLOW):
            return
        # NOTE(review): if the chosen column was unusable the piece lands in a
        # later column, but `move` still records the originally chosen one.

        record = history[foos_turn]
        record['boards'].append(board)
        record['moves'].append(move)
        record['preds'].append(prediction)
        foos_turn = not foos_turn

    game.print_board()

    # Assumes won() only turns true right after a winning move, so the side
    # whose turn it no longer is gets credited with the win.
    winner, loser = (not foos_turn), foos_turn
    won, lost = history[winner], history[loser]
    flipped_moves = [-(m - 1) for m in lost['moves']]
    update_model(nets[winner], won['boards'], won['moves'], won['preds'])
    update_model(nets[loser], lost['boards'], flipped_moves, lost['preds'])
if __name__ == '__main__':
games = 1000
# net = Net()
games = int(sys.argv[1])
if len(sys.argv) > 3:
PATH_ONE = sys.argv[2]
PATH_TWO = sys.argv[3]
net_one = Net()
net_one.load_state_dict(torch.load(PATH_ONE))
net_two = Net()
net_two.load_state_dict(torch.load(PATH_TWO))
else:
net_one = Net()
net_two = Net()
for game_num in range(1,games+1):
play(net_one, net_two)
save_model('one', net_one)
save_model('two', net_two)