class EpsilonGreedy(BaseModel):
    def __init__(
        self,
        evaluator,
        epsilon,
        env_gen,
        optim=None,
        memory_queue=None,
        memory_size=20000,
        mem_type=None,
        batch_size=64,
        gamma=0.99,
    ):
        self.epsilon = epsilon
        self._epsilon = epsilon  # Backup of epsilon, restored when leaving evaluation mode
        self.optim = optim
        self.env = env_gen()
        self.memory_queue = memory_queue
        self.batch_size = batch_size
        self.gamma = gamma

        self.policy_net = copy.deepcopy(evaluator)
        self.target_net = copy.deepcopy(evaluator)

        if mem_type == "sumtree":
            self.memory = WeightedMemory(memory_size)
        else:
            self.memory = Memory(memory_size)

    def __call__(self, s):
        return self._epsilon_greedy(s)

    def q(self, s):
        if not isinstance(s, torch.Tensor):
            s = torch.from_numpy(s).long()
        # s = self.policy_net.preprocess(s)
        return self.policy_net(s)  # Only the predicted policy values

    def _epsilon_greedy(self, s):
        if np.random.rand() < self.epsilon:
            possible_moves = [i for i, move in enumerate(self.env.valid_moves()) if move]
            a = random.choice(possible_moves)
        else:
            weights = self.q(s).detach().cpu().numpy()  # TODO maybe do this with tensors
            # Mask invalid moves with a really big negative number - quite hacky, but effective
            mask = (-1000000000 * ~np.array(self.env.valid_moves())) + 1
            a = np.argmax(weights + mask)
        return a

    def load_state_dict(self, state_dict, target=False):
        self.policy_net.load_state_dict(state_dict)
        if target:
            self.target_net.load_state_dict(state_dict)

    def update(self, s, a, r, done, next_s):
        self.push_to_queue(s, a, r, done, next_s)
        self.pull_from_queue()
        if self.ready:
            self.update_from_memory()

    def push_to_queue(self, s, a, r, done, next_s):
        s = torch.tensor(s, device=device)
        # s = self.policy_net.preprocess(s)
        a = torch.tensor(a, device=device)
        r = torch.tensor(r, device=device)
        done = torch.tensor(done, device=device)
        next_s = torch.tensor(next_s, device=device)
        # next_s = self.policy_net.preprocess(next_s)
        self.memory_queue.put(Transition(s, a, r, done, next_s))

    def pull_from_queue(self):
        while not self.memory_queue.empty():
            experience = self.memory_queue.get()
            self.memory.add(experience)

    def update_from_memory(self):
        if isinstance(self.memory, WeightedMemory):
            tree_idx, batch, sample_weights = self.memory.sample(self.batch_size)
            sample_weights = torch.tensor(sample_weights, device=device)
        else:
            batch = self.memory.sample(self.batch_size)

        batch_t = Transition(*zip(*batch))  # transposed batch
        s_batch, a_batch, r_batch, done_batch, s_next_batch = batch_t
        s_batch = torch.cat(s_batch)
        a_batch = torch.stack(a_batch)
        r_batch = torch.stack(r_batch).view(-1, 1)
        s_next_batch = torch.cat(s_next_batch)
        done_batch = torch.stack(done_batch).view(-1, 1)

        q = self._state_action_value(s_batch, a_batch)

        # Get actual (target) Q values
        double_actions = self.policy_net(s_next_batch).max(1)[1].detach()  # used for double q learning
        q_next = self._state_action_value(s_next_batch, double_actions)
        q_next_actual = (~done_batch) * q_next  # Removes elements that are done
        q_target = r_batch + self.gamma * q_next_actual
        ### TEST if clamping works or is even good practice
        q_target = q_target.clamp(-1, 1)
        ### /TEST

        if isinstance(self.memory, WeightedMemory):
            absolute_loss = torch.abs(q - q_target).detach().cpu().numpy()
            loss = weighted_smooth_l1_loss(
                q, q_target, sample_weights
            )  # TODO fix potential non-linearities using huber loss
            self.memory.batch_update(tree_idx, absolute_loss)
        else:
            loss = F.smooth_l1_loss(q, q_target)

        self.optim.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():  # see if this ends up doing anything - should just be relu
            param.grad.data.clamp_(-1, 1)
        self.optim.step()

    # Determines when the neural net has enough data to train
    @property
    def ready(self):
        return len(self.memory) >= self.memory.max_size

    def state_dict(self):
        return self.policy_net.state_dict()

    def update_target_net(self):
        self.target_net.load_state_dict(self.state_dict())

    def train(self, train_state=True):
        return self.policy_net.train(train_state)

    def reset(self, *args, **kwargs):
        self.env.reset()

    def _state_action_value(self, s, a):
        a = a.view(-1, 1)
        return self.policy_net(s).gather(1, a)

    def evaluate(self, evaluate_state=False):
        # Like train - sets evaluation state (play greedily when evaluating)
        if evaluate_state:
            # self._epsilon = self.epsilon
            self.epsilon = 0
        else:
            self.epsilon = self._epsilon
        # self.evaluating = evaluate_state

    def play_action(self, action, player):
        self.env.step(action, player)
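

# Illustrative usage sketch for EpsilonGreedy (hypothetical helper, not called anywhere in this
# module): shows the assumed wiring of evaluator, env_gen, memory queue, and optimizer, plus a
# single environment step. The env interface (reset / step(action, player) / valid_moves) is
# taken from how the class uses self.env above; everything else here is an assumption, not the
# repo's actual training entry point.
def _example_epsilon_greedy_usage(evaluator, env_gen):
    import queue

    agent = EpsilonGreedy(
        evaluator,
        epsilon=0.1,
        env_gen=env_gen,
        memory_queue=queue.Queue(),
        mem_type="sumtree",
    )
    # The optimizer should wrap the agent's own policy_net copy rather than the passed-in evaluator.
    agent.optim = torch.optim.SGD(agent.policy_net.parameters(), lr=0.01, momentum=0.9)

    s = agent.env.reset()
    a = agent(s)  # epsilon-greedy action restricted to valid moves
    next_s, r, done, _ = agent.env.step(a, player=1)
    agent.update(s, a, r, done, next_s)  # queues the transition and trains once memory is full
    return agent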
class MCTreeSearch(BaseModel):
    def __init__(
        self,
        evaluator,
        env_gen,
        optim=None,
        memory_queue=None,
        iterations=100,
        temperature_cutoff=5,
        batch_size=64,
        memory_size=200000,
        min_memory=20000,
        update_nn=True,
        starting_state_dict=None,
    ):
        self.iterations = iterations
        self.evaluator = evaluator.to(device)
        self.env_gen = env_gen
        self.optim = optim
        self.env = env_gen()
        self.root_node = None
        self.reset()

        self.update_nn = update_nn
        self.starting_state_dict = starting_state_dict

        self.memory_queue = memory_queue
        self.temp_memory = []
        self.memory = Memory(memory_size)
        self.min_memory = min_memory

        self.temperature_cutoff = temperature_cutoff
        self.actions = self.env.action_space.n
        self.evaluating = False
        self.batch_size = batch_size

        if APEX_AVAILABLE:
            opt_level = "O1"
            if self.optim:
                self.evaluator, self.optim = amp.initialize(evaluator, optim, opt_level=opt_level)
                print("updating optimizer and evaluator")
            else:
                self.evaluator = amp.initialize(evaluator, opt_level=opt_level)
                print("updated evaluator")
            self.amp_state_dict = amp.state_dict()
            print(vars(amp._amp_state))

        if self.starting_state_dict:
            print("loading state dict in mcts")
            self.load_state_dict(self.starting_state_dict)

    def reset(self, player=1):
        base_state = self.env.reset()
        probs, v = self.evaluator(base_state)
        self._set_root(MCNode(state=base_state, v=v, player=player))
        self.root_node.create_children(probs, self.env.valid_moves())
        self.moves_played = 0
        self.temp_memory = []
        return base_state

    # Ignores the inputted state for the moment. Produces the correct action, and changes the
    # root node appropriately.
    # TODO might want to do a check on the state to make sure it is consistent
    def __call__(self, s):  # not using player
        move = self._search_and_play()
        return move

    def _search_and_play(self):
        final_temp = 1
        temperature = 1 if self.moves_played < self.temperature_cutoff else final_temp
        self.search()
        move = self._play(temperature)
        return move

    def play_action(self, action, player):
        self._set_node(action)

    def _set_root(self, node):
        if self.root_node:
            self.root_node.active_root = False
        self.root_node = node
        self.root_node.active_root = True

    def _prune(self, action):
        self._set_root(self.root_node.children[action])
        self.root_node.parent.children = [self.root_node]

    def _set_node(self, action):
        node = self.root_node.children[action]
        if self.root_node.children[action].n == 0:
            # TODO check if leaf and n > 0??
            # TODO might not backup??????????
            node, v = self._expand_node(self.root_node, action, self.root_node.player)
            node.backup(v)
            node.v = v
        self._set_root(node)

    def update(self, s, a, r, done, next_s):
        self.push_to_queue(s, a, r, done, next_s)
        self.pull_from_queue()
        if self.ready:
            self.update_from_memory()

    def pull_from_queue(self):
        while not self.memory_queue.empty():
            experience = self.memory_queue.get()
            self.memory.add(experience)

    def push_to_queue(self, s, a, r, done, next_s):
        if done:
            for experience in self.temp_memory:
                experience = experience._replace(actual_val=torch.tensor(r).float().to(device))
                self.memory_queue.put(experience)
            self.temp_memory = []

    def loss(self, batch):
        batch_t = Move(*zip(*batch))  # transposed batch
        s, actual_val, tree_probs = batch_t
        s_batch = torch.stack(s)
        net_probs_batch, predict_val_batch = self.evaluator.forward(s_batch)
        predict_val_batch = predict_val_batch.view(-1)
        actual_val_batch = torch.stack(actual_val)
        tree_probs_batch = torch.stack(tree_probs)

        c = MSELossFlat(floatify=True)
        value_loss = c(predict_val_batch, actual_val_batch)
        prob_loss = -(net_probs_batch.log() * tree_probs_batch).sum() / net_probs_batch.size()[0]
        loss = value_loss + prob_loss
        return loss

    def update_from_memory(self):
        batch = self.memory.sample(self.batch_size)
        loss = self.loss(batch)

        self.optim.zero_grad()
        if APEX_AVAILABLE:
            with amp.scale_loss(loss, self.optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # for param in self.evaluator.parameters():  # see if this ends up doing anything - should just be relu
        #     param.grad.data.clamp_(-1, 1)
        self.optim.step()

    def _play(self, temp=0.05):
        if self.evaluating:  # Might want to just make this greedy
            temp = temp / 20  # More likely to choose higher-visited nodes
        play_probs = [child.play_prob(temp) for child in self.root_node.children]
        play_probs = play_probs / sum(play_probs)

        action = np.random.choice(self.actions, p=play_probs)

        self.moves_played += 1
        self.temp_memory.append(
            Move(
                torch.tensor(self.root_node.state).to(device),
                None,
                torch.tensor(play_probs).float().to(device),
            )
        )
        return action

    def _expand_node(self, parent_node, action, player=1):
        env = self.env_gen()
        env.set_state(copy.copy(parent_node.state))
        s, r, done, _ = env.step(action, player=player)
        r = r * player
        if done:
            v = r
            child_node = parent_node.children[action]
        else:
            probs, v = self.evaluator(s, parent_node.player)
            child_node = parent_node.children[action]
            child_node.create_children(probs, env.valid_moves())
            assert child_node.children
        child_node.state = s
        return child_node, v

    def search(self):
        if self.evaluating:
            self.root_node.add_noise()  # Might want to remove this in evaluation?
        for i in range(self.iterations):
            node = self.root_node
            while True:
                # Mask invalid children with a really big negative number
                select_probs = [
                    child.select_prob if child.valid else -10000000000 for child in node.children
                ]
                action = np.argmax(select_probs + 0.000001 * np.random.rand(self.actions))
                if node.children[action].is_leaf:
                    node, v = self._expand_node(node, action, node.player)
                    node.backup(v)
                    node.v = v
                    break
                else:
                    node = node.children[action]
        if self.evaluating:
            self.root_node.remove_noise()  # Don't think this is necessary?

    def load_state_dict(self, state_dict, target=False):
        self.evaluator.load_state_dict(state_dict)

    # Determines when the neural net has enough data to train
    @property
    def ready(self):
        # Hard code value for the moment
        return len(self.memory) >= self.min_memory and self.update_nn

    def state_dict(self):
        return self.evaluator.state_dict()

    def update_target_net(self):
        # No target net, so pass
        pass

    def deduplicate(self):
        self.memory.deduplicate("state", ["actual_val", "tree_probs"], Move)

    def train(self, train_state=True):
        # Sets training true/false
        return self.evaluator.train(train_state)

    def evaluate(self, evaluate_state=False):
        # Like train - sets evaluate state
        self.evaluating = evaluate_state
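

# Illustrative self-play sketch for MCTreeSearch (hypothetical helper, not part of the repo's
# schedulers): one episode in which a single tree player picks moves for both sides, keeps its
# root in sync via play_action, and flushes the finished game through update so every buffered
# Move gets the final reward as its value target. The +1/-1 player convention and the
# env.step(action, player) interface are taken from the code above; the rest is an assumption.
def _example_mcts_self_play_episode(evaluator, env_gen, optim):
    import queue

    player = MCTreeSearch(evaluator, env_gen, optim=optim, memory_queue=queue.Queue(), iterations=50)
    s = player.reset()
    done, r, a, turn = False, 0, None, 1
    while not done:
        a = player(s)  # runs the tree search and samples a move from the visit counts
        s, r, done, _ = player.env.step(a, player=turn)
        player.play_action(a, turn)  # advances the root node to the chosen child
        turn = -turn
    # On done, this moves temp_memory into the queue/replay memory and trains if enough data exists.
    player.update(s, a, r, done, s)
    return player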
class Q:
    def __init__(
        self,
        env,
        evaluator,
        lr=0.01,
        gamma=0.99,
        momentum=0.9,
        weight_decay=0.01,
        mem_type="sumtree",
        buffer_size=20000,
        batch_size=16,
        *args,
        **kwargs
    ):
        self.gamma = gamma
        self.env = env
        self.state_size = self.env.width * self.env.height

        self.policy_net = copy.deepcopy(evaluator)  # ConvNetConnect4(self.env.width, self.env.height, self.env.action_space.n).to(device)
        self.target_net = copy.deepcopy(evaluator)  # ConvNetConnect4(self.env.width, self.env.height, self.env.action_space.n).to(device)
        self.policy_net.apply(init_weights)
        self.target_net.apply(init_weights)

        self.optim = torch.optim.SGD(
            self.policy_net.parameters(),
            weight_decay=weight_decay,
            momentum=momentum,
            lr=lr,
        )
        if mem_type == "sumtree":
            self.memory = WeightedMemory(buffer_size)
        else:
            self.memory = Memory(buffer_size)
        self.batch_size = batch_size

    def __call__(self, s, player=None):  # TODO use player variable
        if not isinstance(s, torch.Tensor):
            s = torch.from_numpy(s).long()
        s = self.policy_net.preprocess(s)
        return self.policy_net(s)

    def state_action_value(self, s, a):
        a = a.view(-1, 1)
        return self.policy_net(s).gather(1, a)

    def update(self, s, a, r, done, s_next):
        s = torch.tensor(s, device=device)
        # s = self.policy_net.preprocess(s)
        a = torch.tensor(a, device=device)
        r = torch.tensor(r, device=device)
        done = torch.tensor(done, device=device)
        s_next = torch.tensor(s_next, device=device)
        # s_next = self.policy_net.preprocess(s_next)

        if not self.ready:
            self.memory.add(Transition(s, a, r, done, s_next))
            return

        # Using batch memory
        self.memory.add(Transition(s, a, r, done, s_next))
        if isinstance(self.memory, WeightedMemory):
            tree_idx, batch, sample_weights = self.memory.sample(self.batch_size)
            sample_weights = torch.tensor(sample_weights, device=device)
        else:
            batch = self.memory.sample(self.batch_size)
        batch_t = Transition(*zip(*batch))  # transposed batch

        # Get expected Q values
        s_batch, a_batch, r_batch, done_batch, s_next_batch = batch_t
        s_batch = torch.cat(s_batch)
        a_batch = torch.stack(a_batch)
        r_batch = torch.stack(r_batch).view(-1, 1)
        s_next_batch = torch.cat(s_next_batch)
        done_batch = torch.stack(done_batch).view(-1, 1)
        q = self.state_action_value(s_batch, a_batch)

        # Get actual (target) Q values
        double_actions = self.policy_net(s_next_batch).max(1)[1].detach()  # used for double q learning
        q_next = self.state_action_value(s_next_batch, double_actions)
        q_next_actual = (~done_batch) * q_next  # Removes elements that are done
        q_target = r_batch + self.gamma * q_next_actual
        ### TEST if clamping works or is even good practice
        q_target = q_target.clamp(-1, 1)
        ### /TEST

        if isinstance(self.memory, WeightedMemory):
            absolute_loss = torch.abs(q - q_target).detach().cpu().numpy()
            loss = weighted_smooth_l1_loss(
                q, q_target, sample_weights
            )  # TODO fix potential non-linearities using huber loss
            self.memory.batch_update(tree_idx, absolute_loss)
        else:
            loss = F.smooth_l1_loss(q, q_target)

        self.optim.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():  # see if this ends up doing anything - should just be relu
            param.grad.data.clamp_(-1, 1)
        self.optim.step()
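

# Minimal numeric sketch of the Q-target construction used in Q.update and
# EpsilonGreedy.update_from_memory above: (~done_batch) zeroes the bootstrapped next-state
# value for terminal transitions, so they contribute only their reward before the experimental
# clamp. The tensor values here are made up purely for illustration.
def _example_q_target():
    gamma = 0.99
    r_batch = torch.tensor([[1.0], [0.0]])
    done_batch = torch.tensor([[True], [False]])
    q_next = torch.tensor([[0.7], [0.4]])   # Q(s', argmax_a Q_policy(s', a)) from the net
    q_next_actual = (~done_batch) * q_next  # -> [[0.0], [0.4]]
    q_target = (r_batch + gamma * q_next_actual).clamp(-1, 1)
    return q_target                         # -> [[1.0], [0.396]]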
class Elo:
    MODEL_SAVE_FILE = ".ELO_MODEL"
    RESULT_SAVE_FILE = ".ELO_RESULT"
    ELO_VALUE_SAVE_FILE = ".ELO_VALUE"
    ELO_CONSTANT = 400

    def __init__(self, env=Connect4Env):
        path = os.path.dirname(__file__)
        self.model_shelf = shelve.open(os.path.join(path, self.MODEL_SAVE_FILE))
        self.result_shelf = shelve.open(os.path.join(path, self.RESULT_SAVE_FILE))
        self.elo_value_shelf = shelve.open(os.path.join(path, self.ELO_VALUE_SAVE_FILE))
        self.env = env
        self.memory = Memory()
        atexit.register(self._close)

    def _close(self):
        self.model_shelf.close()
        self.result_shelf.close()
        self.elo_value_shelf.close()

    def add_model(self, name, model_container):
        try:
            if self.model_shelf[name]:
                raise ValueError("Model name already in use")
        except KeyError:
            self.model_shelf[name] = model_container
            print(f"added model {name}")

    def compare_models(self, *args):
        combinations = itertools.combinations(args, 2)
        for model_1, model_2 in combinations:
            self._compare(model_1, model_2)

    def _compare(self, model_1, model_2, num_games=100):
        assert model_1 != model_2
        assert "_" not in model_1
        assert "_" not in model_2
        if model_1 > model_2:
            key = f"{model_1}__{model_2}"
            swap = False
        else:
            key = f"{model_2}__{model_1}"
            swap = True
        if key in self.result_shelf:
            old_results = self.result_shelf[key]
        else:
            old_results = {"wins": 0, "draws": 0, "losses": 0}
        new_results = self._get_results(model_1, model_2)
        if swap:
            new_results_ordered = {
                "wins": new_results["losses"],
                "draws": new_results["draws"],
                "losses": new_results["wins"],
            }
        else:
            new_results_ordered = new_results
        total_results = {
            status: new_results_ordered[status] + old_results[status]
            for status in ("wins", "draws", "losses")
        }
        self.result_shelf[key] = total_results

    def _get_results(self, model_1, model_2, num_games=100):
        scheduler = self_play_parallel.SelfPlayScheduler(
            policy_container=self.model_shelf[model_1],
            opposing_policy_container=self.model_shelf[model_2],
            evaluation_policy_container=self.model_shelf[model_2],
            env_gen=self.env,
            epoch_length=num_games,
            initial_games=0,
            self_play=False,
            save_dir=None,
        )
        _, breakdown = scheduler.compare_models()
        results = {
            status: breakdown["first"][status] + breakdown["second"][status]
            for status in ("wins", "draws", "losses")
        }
        print(f"{model_1} wins: {results['wins']} {model_2} wins: {results['losses']} draws: {results['draws']}")
        return results

    def calculate_elo_2(self, anchor_model="random", anchor_elo=0):
        k_factor = 5
        models = list(self.model_shelf.keys())
        model_indices = {model: i for i, model in enumerate(model for model in models if model != anchor_model)}
        if "elo" in self.elo_value_shelf:
            elo_values = self.elo_value_shelf["elo"]
            initial_weights = [elo_values.get(model, 0) for model in models if model != anchor_model]
        else:
            initial_weights = None
        self._convert_memory2(model_indices)
        model_qs = {model: torch.ones(1, requires_grad=True) for model in models}  # q = 10^(rating/400)
        model_qs[anchor_model] = torch.tensor(10 ** (anchor_elo / self.ELO_CONSTANT), requires_grad=False)

        epoch_length = 1000
        num_epochs = 200
        batch_size = 32

        elo_net = EloNetwork(len(models), initial_weights)
        optim = torch.optim.SGD([elo_net.elo_vals.weight], lr=400)

        for i in range(num_epochs):
            for j in range(epoch_length):
                optim.zero_grad()
                batch = self.memory.sample(batch_size)
                batch_t = result_container(*zip(*batch))
                players, results = batch_t
                players = torch.stack(players)
                results = torch.stack(results)
                expected_results = elo_net(players)
                loss = elo_net.loss(expected_results, results)
                loss.backward()
                optim.step()
            for param_group in optim.param_groups:
                param_group["lr"] = param_group["lr"] * 0.99

        model_elos = {
            model: elo_net.elo_vals.weight.tolist()[model_indices[model]]
            for model in models
            if model != anchor_model
        }
        model_elos[anchor_model] = anchor_elo
        self.elo_value_shelf["elo"] = model_elos
        print(model_elos)
        return model_elos

    def calculate_elo(self, anchor_model="random", anchor_elo=0):
        models = list(self.model_shelf.keys())
        model_indices = {model: i for i, model in enumerate(model for model in models if model != anchor_model)}
        if "elo" in self.elo_value_shelf:
            elo_values = self.elo_value_shelf["elo"]
            initial_weights = [elo_values.get(model, 0) for model in models if model != anchor_model]
            print(initial_weights)
        else:
            initial_weights = None
        self._convert_memory(model_indices)
        model_qs = {model: torch.ones(1, requires_grad=True) for model in models}  # q = 10^(rating/400)
        model_qs[anchor_model] = torch.tensor(10 ** (anchor_elo / self.ELO_CONSTANT), requires_grad=False)

        epoch_length = 1000
        num_epochs = 200
        batch_size = 32

        elo_net = EloNetwork(len(models), initial_weights)
        optim = torch.optim.SGD([elo_net.elo_vals.weight], lr=400)

        for i in range(num_epochs):
            for j in range(epoch_length):
                optim.zero_grad()
                batch = self.memory.sample(batch_size)
                batch_t = result_container(*zip(*batch))
                players, results = batch_t
                players = torch.stack(players)
                results = torch.stack(results)
                expected_results = elo_net(players)
                loss = elo_net.loss(expected_results, results)
                loss.backward()
                optim.step()
            for param_group in optim.param_groups:
                param_group["lr"] = param_group["lr"] * 0.99

        model_elos = {
            model: elo_net.elo_vals.weight.tolist()[model_indices[model]]
            for model in models
            if model != anchor_model
        }
        model_elos[anchor_model] = anchor_elo
        self.elo_value_shelf["elo"] = model_elos
        print(model_elos)
        return model_elos

    def _convert_memory(self, model_indices):
        keys = list(self.result_shelf.keys())
        for key in keys:
            model1, model2 = key.split("__")
            val_1 = self._onehot(model1, model_indices)
            val_2 = self._onehot(model2, model_indices)
            players = torch.stack((val_1, val_2), 1).t()
            results = self.result_shelf[key]
            result_map = {"wins": 1, "losses": 0, "draws": 0.5}
            for result, value in result_map.items():
                for _ in range(results[result]):
                    self.memory.add(result_container(players, torch.tensor(value, dtype=torch.float)))

    def _convert_memory2(self, model_indices):
        keys = list(self.result_shelf.keys())
        for key in keys:
            model1, model2 = key.split("__")
            val_1 = self._onehot(model1, model_indices)
            val_2 = self._onehot(model2, model_indices)
            players = torch.stack((val_1, val_2), 1).t()
            results = self.result_shelf[key]
            result_map = {"wins": 1, "losses": 0, "draws": 0.5}
            for result, value in result_map.items():
                for _ in range(results[result]):
                    if result == "wins":
                        self.memory.add(result_container(players, torch.tensor(value, dtype=torch.float)))
                        self.memory.add(result_container(players, torch.tensor(value, dtype=torch.float)))
                    if result == "losses":
                        self.memory.add(result_container(players, torch.tensor(value, dtype=torch.float)))
                        self.memory.add(result_container(players, torch.tensor(value, dtype=torch.float)))
                    if result == "draws":
                        self.memory.add(result_container(players, torch.tensor(0, dtype=torch.float)))
                        self.memory.add(result_container(players, torch.tensor(1, dtype=torch.float)))

    def _onehot(self, model, model_indices):
        model_idx = model_indices[model] if model in model_indices else None
        if model_idx is not None:
            val = torch.nn.functional.one_hot(torch.tensor(model_idx), len(model_indices))
        else:
            val = torch.zeros(len(model_indices), dtype=torch.long)
        return val

    def manual_play(self, model_name):
        model = self.model_shelf[model_name].setup()
        model.train(False)
        model.evaluate(True)
        manual = ManualPlay(Connect4Env(), model)
        manual.play()

    def observe(self, model_name, opponent_model_name):
        model = self.model_shelf[model_name].setup()
        opponent_model = self.model_shelf[opponent_model_name].setup()
        model.train(False)
        opponent_model.train(False)
        model.evaluate(True)
        opponent_model.evaluate(True)
        view = View(Connect4Env(), model, opponent_model)
        view.play()
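

# Illustrative end-to-end sketch of the Elo workflow above (the model names and containers are
# hypothetical): register model containers, accumulate head-to-head results in the result shelf,
# then fit ratings with the gradient-based solver. Judging by the q = 10^(rating/400) comment
# above, the fit presumably follows the standard Elo expected-score model,
# E(A beats B) = 1 / (1 + 10 ** ((elo_B - elo_A) / ELO_CONSTANT)),
# with the anchor model ("random") pinned to a rating of 0.
def _example_elo_workflow(random_container, mcts_container, dqn_container):
    elo = Elo(env=Connect4Env)
    elo.add_model("random", random_container)
    elo.add_model("mcts", mcts_container)
    elo.add_model("dqn", dqn_container)
    elo.compare_models("random", "mcts", "dqn")  # plays every pairing, 100 games each
    ratings = elo.calculate_elo(anchor_model="random", anchor_elo=0)
    return ratings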