def server_routine(s2w_conns, num_processes=8):
    """Batched-inference server loop for the self-play worker processes.

    Repeatedly gathers one (state, mask) pair from every still-running
    worker over its pipe, evaluates the whole batch in a single network
    call, and sends each worker back its (value, prior) slice. Stops when
    every worker has signalled completion.

    Args:
        s2w_conns: server-side `Pipe` connection per worker; `recv()` is
            expected to yield `(state, mask, done_flag)` tuples.
        num_processes: number of worker connections to serve.
    """
    infnet = InferenceNetwork(STATES_SIZE, MASK_SIZE)
    # BUG FIX: this was hard-coded to `[False] * 8`, ignoring the
    # `num_processes` parameter — any other worker count would misbehave.
    done_flags = [False] * num_processes
    # A throwaway game provides padding states/masks for finished workers
    # so the stacked batch keeps a constant shape.
    dummy = azul.Azul(2)
    dummy.start()
    dummy_status = (dummy.states(), dummy.flat_mask())
    while not all(done_flags):
        states, masks = [], []
        for i in range(num_processes):
            if done_flags[i]:
                state, mask = dummy_status
            else:
                state, mask, flag = s2w_conns[i].recv()
                if flag:  # worker reports its game is over
                    done_flags[i] = True
                    state, mask = dummy_status
            states.append(state)
            masks.append(mask)
        states = np.stack(states, axis=0)
        masks = np.stack(masks, axis=0)
        values, priors = infnet.predict(states, masks)
        # Only live workers get a reply; finished ones no longer recv().
        for i in range(num_processes):
            if not done_flags[i]:
                s2w_conns[i].send((values[i], priors[i]))
    infnet.close()
def __init__(self, game, inference_fuction, commands):
    """Build a search tree rooted at a deep copy of *game*.

    Args:
        game: current game position; deep-copied so the search never
            mutates the caller's game object.
        inference_fuction: callable mapping a game to (value, prior).
            (Name keeps the original spelling for caller compatibility.)
        commands: command table installed on the node class.
    """
    root_game = deepcopy(game)
    self._infnet = inference_fuction
    # NOTE(review): this started dummy game is stored but never read in
    # this block — presumably other methods use it; confirm before removing.
    self._dummy = azul.Azul(2)
    self._dummy.start()
    # Only the prior is needed to seed the root; the value is discarded
    # (the original bound it to an unused `root_value` local).
    _, root_prior = self._infnet(root_game)
    self._root = _MCTNode(root_game, root_prior, False)
    # Class-level command table shared by all nodes of this tree.
    _MCTNode.commands = commands
    self._choice = None
def make_envs(self):
    """Reset per-environment bookkeeping: one dict per environment.

    Each dict holds a fresh two-player game plus empty trajectory
    buffers, a done flag, and a winner slot (0 until decided).
    """
    self._env_dics = [
        {
            'env': azul.Azul(2),
            'envs': [],
            'obs': [],
            'masks': [],
            'acs': [],
            'probs': [],
            'terminals': [],
            'done': False,
            'winner': 0,
        }
        for _ in range(self._num_env)
    ]
def debug():
    """Play one full self-play game with MCTS and return its training data.

    Returns:
        (state_data, action_data, value_data, searches) where actions are
        re-encoded as three-digit strings and values are +1.0 for the
        winner's moves, -1.0 otherwise; `searches` keeps every tree built.
    """
    game = azul.Azul(2)
    game.start()
    # Every (6, 5, 6) index is a legal command slot in the command table.
    commands = np.argwhere(np.ones((6, 5, 6)) == 1)
    inf_helper = InfHelperS()
    search = mcts.MCTSearch(game, inf_helper, commands)
    searches = [search]
    accumulated_data = []
    winner = None
    while True:
        action_command, training_data = search.start_search(300)
        accumulated_data.append(training_data)
        if game.take_command(action_command):
            game.turn_end(verbose=False)
            if game.is_terminal:
                game.final_score()
                winner = game.leading_player_num
                break
            game.start_turn()
        # A fresh tree is built for each new position (instead of the
        # alternative of re-rooting the existing one).
        search = mcts.MCTSearch(game, inf_helper, commands)
        searches.append(search)
    state_data, action_data, value_data = [], [], []
    for state, action_index, player in accumulated_data:
        state_data.append(state)
        # Flat index split into mixed-radix digits 30 / 6; presumably the
        # three command axes — confirm against the command table layout.
        encoded = f"{action_index // 30}{(action_index % 30) // 6}{action_index % 6}"
        action_data.append(encoded)
        value_data.append(1. if player == winner else -1.)
    return state_data, action_data, value_data, searches
import tensorflow as tf import numpy as np import alphazul from copy import deepcopy import azul import policy import mcts if __name__ == '__main__': game = azul.Azul(2) game.start() # states = game.states() # mask = game.flat_mask() # inf = alphazul.InferenceNetwork(states.shape[0],mask.shape[0]) # value, prior = inf.predict([states],[mask]) # print(prior,np.sum(prior)) # print(np.sum(prior>0),np.sum(mask)) search = mcts.MCTSearch(game, policy.rollout, commands=np.argwhere(np.ones((6, 5, 6)) == 1)) while True: # game.display() # print('------------------------------------') # search._root.game.display() # print('\n\n') action, _ = search.start_search(100) is_turn_end = game.take_command(action) if is_turn_end:
def __init__(self, w2s_conn):
    """Keep the worker-to-server pipe end and set up a started dummy game."""
    self._w2s_conn = w2s_conn
    dummy_game = azul.Azul(2)
    dummy_game.start()
    self._dummy = dummy_game
def self_play():
    """Run 8 self-play workers plus one inference server, pool their
    trajectories into a balanced dataset, and train the network on it.

    Each worker plays a full game, pushing
    (state_data, action_data, value_data, mask_data) onto a shared queue.
    Every game is shuffled internally and truncated to the shortest
    game's length so each game contributes equally, then the pooled data
    is globally permuted and fed to the network in mini-batches.
    """
    num_workers = 8
    processes = []
    s2w_conns = []
    public_q = Queue()
    # Define workers: one game and one dedicated pipe to the server each.
    for _ in range(num_workers):
        game = azul.Azul(2)
        game.start()
        w2s_conn, s2w_conn = Pipe()
        s2w_conns.append(s2w_conn)
        processes.append(
            Process(target=worker_routine, args=(game, w2s_conn, public_q)))
    # Define the batched-inference server over all worker pipes.
    server = Process(target=server_routine, args=(s2w_conns, ))
    # Start server first so workers never block on an absent peer.
    server.start()
    for p in processes:
        p.start()
    # One queue item per worker; get() blocks until that game finishes.
    all_data = [public_q.get() for _ in range(num_workers)]
    # BUG FIX: the original initialised min_length to the magic number 999
    # and compared with <=, so any game longer than 999 plies would not be
    # truncated correctly. Derive the true minimum trajectory length.
    min_length = min(len(state_data) for state_data, _, _, _ in all_data)
    state_data_all, action_data_all, value_data_all, mask_data_all = [], [], [], []
    for state_data, action_data, value_data, mask_data in all_data:
        # Shuffle within each game before truncating so every game
        # contributes min_length uniformly-sampled positions.
        data_zip = list(zip(state_data, action_data, value_data, mask_data))
        random.shuffle(data_zip)
        state_data, action_data, value_data, mask_data = zip(*data_zip)
        state_data_all.extend(state_data[:min_length])
        action_data_all.extend(action_data[:min_length])
        value_data_all.extend(value_data[:min_length])
        mask_data_all.extend(mask_data[:min_length])
    state_data_all = np.stack(state_data_all)
    action_data_all = np.stack(action_data_all)
    value_data_all = np.stack(value_data_all).reshape((-1, 1))
    mask_data_all = np.stack(mask_data_all)
    assert (len(state_data_all) == len(action_data_all)
            == len(value_data_all) == len(mask_data_all))
    # Global shuffle across games for i.i.d.-ish mini-batches.
    permutated_index = np.random.permutation(len(state_data_all))
    permutated_state = state_data_all[permutated_index]
    permutated_action = action_data_all[permutated_index]
    permutated_value = value_data_all[permutated_index]
    permutated_mask = mask_data_all[permutated_index]
    for p in processes:
        p.join()
    server.join()
    # Train in fixed-size batches; the tail remainder (< BATCH_SIZE) is dropped.
    num_iter = len(permutated_state) // BATCH_SIZE
    infnet = InferenceNetwork(STATES_SIZE, MASK_SIZE)
    for i in range(num_iter):
        infnet.train(permutated_state[i * BATCH_SIZE:(i + 1) * BATCH_SIZE],
                     permutated_action[i * BATCH_SIZE:(i + 1) * BATCH_SIZE],
                     permutated_value[i * BATCH_SIZE:(i + 1) * BATCH_SIZE],
                     permutated_mask[i * BATCH_SIZE:(i + 1) * BATCH_SIZE])
        print(i)
    infnet.close()