Пример #1
0
 def create_grapher(self):
     '''Create the Grapher for ``losses_path`` and store it on ``self.grapher``.'''
     loss_grapher = Grapher(losses_path)
     self.grapher = loss_grapher
Пример #2
0
class OptimizerHandler:
    '''Runs the self-play training loop for the agent network of a match handler.

    Every training iteration first checks whether the opponent should be
    replaced, then plays one match through ``match_handler`` and finally
    performs ``n_train_per_game`` SGD steps on samples drawn from the
    match handler's experience replay.

    Relies on names provided by the enclosing module: ``Grapher``,
    ``losses_path``, ``model_path``, ``device``, ``torch``, ``nn``, ``optim``.
    '''

    def __init__(self, match_handler, batch_size, learning_rate,
                 n_train_per_game, min_win_rate):
        '''Stores the hyper-parameters and builds grapher, optimizer and loss.

        Args:
            match_handler: object exposing ``agent_mcts.network``,
                ``opponent_mcts.network``, ``experience_replay``,
                ``play_match`` and the latest-results counters read in
                ``replace_opponent_if_needed``.
            batch_size: number of replay samples per SGD step.
            learning_rate: learning rate handed to ``optim.SGD``.
            n_train_per_game: SGD steps performed after every match.
            min_win_rate: win rate that must be exceeded before the agent
                is saved and copied into the opponent.
        '''
        self.match_handler = match_handler
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.n_train_per_game = n_train_per_game
        self.min_win_rate = min_win_rate

        self.create_grapher()
        self.create_optim()
        self.MSE = nn.MSELoss()

    def create_grapher(self):
        '''Creates the Grapher that receives the loss values.'''
        self.grapher = Grapher(losses_path)

    def create_optim(self):
        '''Creates the SGD optimizer over the agent network parameters.'''
        self.optimizer = optim.SGD(
            self.match_handler.agent_mcts.network.parameters(),
            weight_decay=0.001,
            lr=self.learning_rate,
            momentum=0.9)
        # Small constant added inside the log of the policy loss so that
        # log(0) is never evaluated.
        self.eps = torch.FloatTensor([1e-8]).to(device)

    def optimize_model(self):
        '''Applies one iteration of SGD on agent network (if the replay holds
           more than batch_size entries), target is sampled from experience replay.

           The loss is the MSE between predicted and target value minus the
           mean of ``gt_probs * log(probs + eps)``. The loss value is written
           to the grapher. The network is switched to train mode for the
           duration of the call and back to eval mode afterwards.'''
        model = self.match_handler.agent_mcts.network
        experience_replay = self.match_handler.experience_replay
        model.train()
        if len(experience_replay.deque) > self.batch_size:
            samples = experience_replay.sample(self.batch_size)
            # Each sample is a (target probabilities, target value, state) triple.
            gt_probs, gt_v, s = zip(*samples)
            gt_probs = torch.stack(gt_probs)
            gt_v = torch.stack(gt_v)
            s = torch.stack(s)

            self.optimizer.zero_grad()
            probs, v = model(s)
            loss = self.MSE(v, gt_v) - torch.mean(
                gt_probs * torch.log(probs + self.eps))
            loss.backward()
            self.optimizer.step()

            self.grapher.write(str(loss.data.cpu().numpy()))
        model.eval()

    def train(self, mcts_steps, eps, n_iterations=10000000):
        '''Training loop, one self-play match and self.n_train_per_game SGD
           iterations per pass.

        Args:
            mcts_steps: forwarded to ``match_handler.play_match``.
            eps: forwarded to ``match_handler.play_match``.
            n_iterations: number of passes of the loop.
        '''
        for i in range(n_iterations):
            self.replace_opponent_if_needed(i)
            self.match_handler.play_match(mcts_steps, eps)
            for _ in range(self.n_train_per_game):
                self.optimize_model()

    def replace_opponent_if_needed(self, i):
        '''Checks if the current winrate is above a certain threshold.
           If yes, updates opponent and saves agent.

           The win rate is wins / (wins + losses); draws are ignored. When
           the evaluation window contains no decisive game the win rate is
           taken as 0.0, so the opponent is not replaced.

        Args:
            i: current training iteration, used to print progress every
                tenth iteration.
        '''
        handler = self.match_handler
        n_eval = handler.n_eval
        n_wins = handler.n_latest_wins
        n_losses = handler.n_latest_losses
        n_draws = handler.n_latest_draws
        if len(handler.deque_latest_results) >= n_eval:
            n_decisive = n_wins + n_losses
            # An all-draw window has no decisive games; avoid dividing by zero.
            win_rate = n_wins / n_decisive if n_decisive > 0 else 0.0
            if i % 10 == 0:
                print("iter", i, "win_rate:", win_rate, "wins:", n_wins,
                      "losses:", n_losses, "draws:", n_draws)
            if win_rate > self.min_win_rate:
                self.save_model_and_update_opponent()
                # Start a fresh evaluation window against the new opponent.
                handler.create_results_tracker()
                print("\nwin rate reached", win_rate, "saving model\n")

    def save_model_and_update_opponent(self):
        '''Saves agent network to file and loads it into opponent'''
        agent_network = self.match_handler.agent_mcts.network
        opponent_network = self.match_handler.opponent_mcts.network
        torch.save(agent_network.state_dict(), model_path)
        opponent_network.load_state_dict(torch.load(model_path))
Пример #3
0
class OptimizerHandler:
    '''Trains the agent network by repeatedly attempting shuffled cubes.

    Every training iteration first checks whether ``n_shuffle`` should be
    increased, then attempts one cube with MCTS and finally performs
    ``n_train_per_solve`` SGD steps on samples from the agent's experience
    replay. The MCTS step count and wall-clock time of the latest
    ``n_eval`` attempts are tracked as rolling averages.

    Relies on names provided by the enclosing module: ``Grapher``,
    ``model_path``, ``experience_path``, ``torch``, ``nn``, ``optim``,
    ``collections``, ``pickle``, ``time``.
    '''

    def __init__(self, agent, batch_size, learning_rate, n_train_per_solve,
                 n_eval, min_mcts_steps, n_shuffle):
        '''Stores the hyper-parameters and builds optimizer, grapher, tracker.

        Args:
            agent: object exposing ``mcts``, ``experience_replay`` and
                ``add_to_experience_replay``.
            batch_size: number of replay samples per SGD step.
            learning_rate: learning rate handed to ``optim.SGD``.
            n_train_per_solve: SGD steps performed after every attempt.
            n_eval: size of the window of latest attempts used for the
                rolling averages.
            min_mcts_steps: ``n_shuffle`` is increased once the average
                number of MCTS steps over a full window drops below this.
            n_shuffle: value passed to ``agent.mcts.reset``
                (presumably the number of shuffle moves - confirm).
        '''
        self.agent = agent
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.n_train_per_solve = n_train_per_solve
        self.n_eval = n_eval
        self.min_mcts_steps = min_mcts_steps
        self.n_shuffle = n_shuffle

        self.create_optim()
        self.create_grapher()
        self.create_results_tracker()

        self.MSE = nn.MSELoss()

    def create_results_tracker(self):
        '''Resets the rolling statistics; -1 marks "no attempt recorded yet".'''
        self.average_mcts_steps = -1
        self.average_solve_time = -1
        self.deque_latest_results = collections.deque(maxlen=self.n_eval)

    def create_grapher(self):
        '''Creates the Grapher for the loss file of the current n_shuffle.'''
        self.grapher = Grapher("save_dir/loss_folder/losses_n" +
                               str(self.n_shuffle) + ".txt")

    def create_optim(self):
        '''Creates the SGD optimizer over the agent network parameters.'''
        self.optimizer = optim.SGD(self.agent.mcts.network.parameters(),
                                   weight_decay=0.0001,
                                   lr=self.learning_rate,
                                   momentum=0.9)

    def optimize_model(self):
        '''Applies one SGD step on the agent network if the replay holds more
           than batch_size entries.

           The loss is the MSE between the network output and the sampled
           targets; its value is written to the grapher. The network is
           switched to train mode for the call and back to eval afterwards.'''
        model = self.agent.mcts.network
        experience_replay = self.agent.experience_replay
        model.train()
        if len(experience_replay.deque) > self.batch_size:
            samples = experience_replay.sample(self.batch_size)
            # Each sample is a (state, target) pair.
            s, target = zip(*samples)
            s = torch.stack(s)
            target = torch.stack(target)
            self.optimizer.zero_grad()
            loss = self.MSE(model(s), target)
            loss.backward()
            self.optimizer.step()

            self.grapher.write(str(loss.data.cpu().numpy()))
        model.eval()

    def save_model_and_reset_grapher(self):
        '''Saves network weights and the experience replay, then creates a
           new grapher (its file name depends on the current n_shuffle).'''
        agent_network = self.agent.mcts.network
        torch.save(agent_network.state_dict(), model_path)
        with open(experience_path, "wb") as f:
            pickle.dump(self.agent.experience_replay, f,
                        pickle.HIGHEST_PROTOCOL)
        self.create_grapher()

    def train(self, eps, n_iterations=10000000):
        '''Training loop: one cube attempt and self.n_train_per_solve SGD
           steps per pass.

        Args:
            eps: forwarded to ``attempt_random_cube``.
            n_iterations: number of passes of the loop.
        '''
        for i in range(n_iterations):
            self.check_if_increase_n_shuffle(i)
            self.attempt_random_cube(eps)
            for _ in range(self.n_train_per_solve):
                self.optimize_model()

    def check_if_increase_n_shuffle(self, i):
        '''Prints progress every tenth iteration and increases n_shuffle once
           a full window of attempts averages fewer than min_mcts_steps steps.

        Args:
            i: current training iteration.
        '''
        if i % 10 == 0:
            print("it", i, "avg steps", self.average_mcts_steps, "avg time",
                  self.average_solve_time, "n_eval",
                  len(self.deque_latest_results))
        if len(self.deque_latest_results) >= self.n_eval:
            if self.average_mcts_steps < self.min_mcts_steps:
                self.n_shuffle += 1
                print("\n\n", "mcts length reached", self.average_mcts_steps)
                print("saving model and increasing nshuffle to",
                      self.n_shuffle, "\n\n")
                self.create_results_tracker()
                self.save_model_and_reset_grapher()

    def attempt_random_cube(self, eps):
        '''Attempts one cube, stores the visited nodes in the replay and
           updates the rolling statistics.

        Args:
            eps: forwarded to ``agent.mcts.monte_carlo_tree_search``.
        '''
        mcts = self.agent.mcts

        # Search one MCTS step at a time until the root reports a solution
        # or the step budget is exhausted.
        mcts.reset(self.n_shuffle)
        solve_time = time.time()
        for mcts_steps in range(10000):
            mcts.monte_carlo_tree_search(1, eps)
            if mcts.root.solution_found:
                break
        else:
            print("reached max search steps")

        # Walk down the tree along get_best_action(eps=0) until
        # change_root_with_action reports termination or the move cap is hit.
        for _ in range(min(40, int(self.n_shuffle * 1.5))):
            a = mcts.get_best_action(eps=0)
            if mcts.change_root_with_action(a):
                break
        solve_time = time.time() - solve_time
        self.traverse_and_add_to_replay()

        self._record_result(mcts_steps, solve_time)

    def _record_result(self, mcts_steps, solve_time):
        '''Appends one attempt to the window and refreshes both averages.'''
        # The deque is bounded by maxlen=n_eval, so appending to a full
        # window drops the oldest entry automatically.
        self.deque_latest_results.append([mcts_steps, solve_time])
        n = len(self.deque_latest_results)
        if n == 0:
            # Only possible with a zero-length window; nothing to average.
            return
        # Recompute from the window instead of updating incrementally, so the
        # averages always match the stored entries exactly.
        self.average_mcts_steps = sum(
            steps for steps, _ in self.deque_latest_results) / n
        self.average_solve_time = sum(
            seconds for _, seconds in self.deque_latest_results) / n

    def traverse_and_add_to_replay(self):
        '''Adds every ancestor of the current MCTS root to the experience
           replay exactly once, from the root's parent up to the top-most
           node. Does nothing when the root has no parent.'''
        node = self.agent.mcts.root.parent
        while node is not None:
            self.agent.add_to_experience_replay(node)
            node = node.parent
Пример #4
0
 def create_grapher(self):
     '''Create the Grapher for the loss file of the current ``n_shuffle``
     and store it on ``self.grapher``.'''
     file_suffix = str(self.n_shuffle) + ".txt"
     loss_file = "save_dir/loss_folder/losses_n" + file_suffix
     self.grapher = Grapher(loss_file)
Пример #5
0
class OptimizerHandler:
    '''Trains the agent network by repeatedly attempting shuffled cubes.

    Every training iteration first checks whether ``n_shuffle`` should be
    increased, then attempts one cube with MCTS and finally performs
    ``n_train_per_solve`` SGD steps on samples from the agent's experience
    replay. The outcomes of the latest ``n_eval`` attempts are tracked to
    compute a solve rate.

    Relies on names provided by the enclosing module: ``Grapher``,
    ``model_path``, ``experience_path``, ``torch``, ``nn``, ``optim``,
    ``collections``, ``pickle``.
    '''

    def __init__(self, agent, batch_size, learning_rate, n_train_per_solve,
                 n_eval, min_solve_rate, n_shuffle):
        '''Stores the hyper-parameters and builds optimizer, grapher, tracker.

        Args:
            agent: object exposing ``mcts``, ``experience_replay`` and
                ``add_to_experience_replay``.
            batch_size: number of replay samples per SGD step.
            learning_rate: learning rate handed to ``optim.SGD``.
            n_train_per_solve: SGD steps performed after every attempt.
            n_eval: size of the window of latest attempts used for the
                solve rate.
            min_solve_rate: ``n_shuffle`` is increased once the solve rate
                over a full window exceeds this value.
            n_shuffle: value passed to ``agent.mcts.reset``
                (presumably the number of shuffle moves - confirm).
        '''
        self.agent = agent
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.n_train_per_solve = n_train_per_solve
        self.n_eval = n_eval
        self.min_solve_rate = min_solve_rate
        self.n_shuffle = n_shuffle

        self.create_optim()
        self.create_grapher()
        self.create_results_tracker()

        self.MSE = nn.MSELoss()

    def create_results_tracker(self):
        '''Resets the win/loss counters and the window of latest results.'''
        self.n_latest_wins = 0
        self.n_latest_losses = 0
        self.deque_latest_results = collections.deque(maxlen=self.n_eval)

    def create_grapher(self):
        '''Creates the Grapher for the loss file of the current n_shuffle.'''
        self.grapher = Grapher("save_dir/loss_folder/losses_n" +
                               str(self.n_shuffle) + ".txt")

    def create_optim(self):
        '''Creates the SGD optimizer over the agent network parameters.'''
        self.optimizer = optim.SGD(self.agent.mcts.network.parameters(),
                                   weight_decay=0.0001,
                                   lr=self.learning_rate,
                                   momentum=0.9)

    def optimize_model(self):
        '''Applies one SGD step on the agent network if the replay holds more
           than batch_size entries.

           The loss is the MSE between the network output and the sampled
           targets; its value is written to the grapher. The network is
           switched to train mode for the call and back to eval afterwards.'''
        model = self.agent.mcts.network
        experience_replay = self.agent.experience_replay
        model.train()
        if len(experience_replay.deque) > self.batch_size:
            samples = experience_replay.sample(self.batch_size)
            # Each sample is a (state, target) pair.
            s, target = zip(*samples)
            s = torch.stack(s)
            target = torch.stack(target)
            self.optimizer.zero_grad()
            loss = self.MSE(model(s), target)
            loss.backward()
            self.optimizer.step()

            self.grapher.write(str(loss.data.cpu().numpy()))
        model.eval()

    def save_model_and_reset_grapher(self):
        '''Saves network weights and the experience replay, then creates a
           new grapher (its file name depends on the current n_shuffle).'''
        agent_network = self.agent.mcts.network
        torch.save(agent_network.state_dict(), model_path)
        with open(experience_path, "wb") as f:
            pickle.dump(self.agent.experience_replay, f,
                        pickle.HIGHEST_PROTOCOL)
        self.create_grapher()

    def train(self, max_mcts_steps, mcts_eps, final_choose_eps,
              n_iterations=10000000):
        '''Training loop: one cube attempt and self.n_train_per_solve SGD
           steps per pass.

        Args:
            max_mcts_steps: forwarded to ``attempt_random_cube``.
            mcts_eps: forwarded to ``attempt_random_cube``.
            final_choose_eps: forwarded to ``attempt_random_cube``.
            n_iterations: number of passes of the loop.
        '''
        for i in range(n_iterations):
            self.check_if_increase_n_shuffle(i)
            self.attempt_random_cube(max_mcts_steps, mcts_eps,
                                     final_choose_eps)
            for _ in range(self.n_train_per_solve):
                self.optimize_model()

    def check_if_increase_n_shuffle(self, i):
        '''Increases n_shuffle once a full window of attempts has a solve
           rate above min_solve_rate; prints the rate every tenth iteration.

        Args:
            i: current training iteration.
        '''
        if len(self.deque_latest_results) >= self.n_eval:
            solve_rate = self.n_latest_wins / self.n_eval
            if i % 10 == 0:
                print(i, "solve_rate", solve_rate)
            if solve_rate > self.min_solve_rate:
                self.n_shuffle += 1
                self.create_results_tracker()
                self.save_model_and_reset_grapher()
                print("solve rate reached", solve_rate)
                print("saving model and increasing nshuffle to",
                      self.n_shuffle)

    def attempt_random_cube(self, max_mcts_steps, mcts_eps, final_choose_eps):
        '''Attempts one cube, stores the visited nodes in the replay and
           records whether the attempt terminated successfully.

        Args:
            max_mcts_steps: forwarded to ``agent.mcts.monte_carlo_tree_search``.
            mcts_eps: forwarded to ``agent.mcts.monte_carlo_tree_search``.
            final_choose_eps: forwarded to
                ``agent.mcts.monte_carlo_tree_search``.
        '''
        mcts = self.agent.mcts

        # Solve cube: search, apply the returned action, repeat until
        # change_root_with_action reports termination or the move cap is hit.
        mcts.reset(self.n_shuffle)
        # Stays False when the move loop below runs zero times
        # (int(n_shuffle * 1.5) == 0).
        terminate = False
        for _ in range(min(35, int(self.n_shuffle * 1.5))):
            a = mcts.monte_carlo_tree_search(
                max_mcts_steps, mcts_eps, final_choose_eps)
            terminate = mcts.change_root_with_action(a)
            if terminate:
                break
        self.traverse_and_add_to_replay()

        # Update eval statistics. The oldest result is removed by hand
        # (rather than by the deque's maxlen) so the counters stay in sync.
        if len(self.deque_latest_results) == self.n_eval:
            if self.deque_latest_results.popleft():
                self.n_latest_wins -= 1
            else:
                self.n_latest_losses -= 1

        # Same truthiness test as the loop's break condition.
        solved = bool(terminate)
        self.deque_latest_results.append(solved)
        if solved:
            self.n_latest_wins += 1
        else:
            self.n_latest_losses += 1

    def traverse_and_add_to_replay(self):
        '''Adds every ancestor of the current MCTS root to the experience
           replay exactly once, from the root's parent up to the top-most
           node. Does nothing when the root has no parent.'''
        node = self.agent.mcts.root.parent
        while node is not None:
            self.agent.add_to_experience_replay(node)
            node = node.parent