def run_model(game_count=1):
    """ Run the pretrained agent for game_count games and print the trajectories. """
    # Make environment
    env = WhaleEnv(
        config={
            'active_player': 0,
            'seed': datetime.utcnow().microsecond,
            'env_num': 1,
            'num_players': 5
        })
    # Set up agents: the pretrained SimpleAgent plays against four random agents
    action_num = 3
    agent = SimpleAgent(action_num=action_num, player_num=5)
    agent_0 = RandomAgent(action_num=action_num)
    agent_1 = RandomAgent(action_num=action_num)
    agent_2 = RandomAgent(action_num=action_num)
    agent_3 = RandomAgent(action_num=action_num)
    agents = [agent, agent_0, agent_1, agent_2, agent_3]
    env.set_agents(agents)
    agent.load_pretrained()
    for game in range(game_count):
        # Generate data from the environment
        trajectories = env.run(is_training=False)
        # Print out the trajectories, one block per player
        print('\nEpisode {}'.format(game))
        for player_id, trajectory in enumerate(trajectories):
            print('\tPlayer {}'.format(player_id))
            for transition in trajectory:
                print(transition)
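
# Hypothetical helper, not part of the original module: condense one player's
# trajectory into a one-line summary instead of dumping raw transitions. This
# is a sketch that assumes an RLCard-style transition layout of
# (state, action, reward, next_state, done) tuples; adjust the indexing if
# this project stores transitions differently.
def summarize_trajectory(trajectory):
    """ Return a short summary string for one player's trajectory. """
    if not trajectory:
        return 'no transitions'
    steps = len(trajectory)
    final_reward = trajectory[-1][2]  # reward field of the last transition
    return '{} transitions, final reward {:.2f}'.format(steps, final_reward)
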
def train_model(max_episodes=100):
    """
    Trains the SimpleAgent to play the Whale game by trial and error
    :return: None
    """
    # buffer = ReplayBuffer()
    # Make environment
    env = WhaleEnv(
        config={
            'active_player': 0,
            'seed': datetime.utcnow().microsecond,
            'env_num': 1,
            'num_players': 5
        })
    # Set up agents: the learner plays against four NoDrawAgent opponents
    action_num = 3
    agent = SimpleAgent(action_num=action_num, player_num=5)
    agent_0 = NoDrawAgent(action_num=action_num)
    agent_1 = NoDrawAgent(action_num=action_num)
    agent_2 = NoDrawAgent(action_num=action_num)
    agent_3 = NoDrawAgent(action_num=action_num)
    # agent_train = RandomAgent(action_num=action_num)
    agents = [agent, agent_0, agent_1, agent_2, agent_3]
    # train_agents = [agent_train, agent_0, agent_1, agent_2, agent_3]
    env.set_agents(agents)
    agent.load_pretrained()
    min_perf, max_perf = 1.0, 0.0
    for episode_cnt in range(1, max_episodes + 1):
        # Collect gameplay experience and fit the agent on it
        loss = agent.train(
            collect_gameplay_experiences(env, agents, GAME_COUNT_PER_EPISODE))
        # Evaluate the current policy and track the best/worst performance
        avg_rewards = evaluate_training_result(env, agents,
                                               EVAL_EPISODES_COUNT)
        if avg_rewards[0] > max_perf:
            max_perf = avg_rewards[0]
            agent.save_weight()  # checkpoint only on a new best score
        if avg_rewards[0] < min_perf:
            min_perf = avg_rewards[0]
        print('{0:03d}/{1} perf:{2:.2f}(min:{3:.2f} max:{4:.2f}) '
              'loss:{5:.4f} rewards:{6:.2f} {7:.2f} {8:.2f} {9:.2f}'.format(
                  episode_cnt, max_episodes, avg_rewards[0], min_perf,
                  max_perf, loss[0], avg_rewards[1], avg_rewards[2],
                  avg_rewards[3], avg_rewards[4]))
    # env.close()
    print('training end')
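
# A minimal entry-point sketch (an assumption, not in the original source):
# train for the default number of episodes, then replay a few games with the
# best saved weights. Both functions rely on helpers and constants defined
# elsewhere in this project (collect_gameplay_experiences,
# evaluate_training_result, GAME_COUNT_PER_EPISODE, EVAL_EPISODES_COUNT).
if __name__ == '__main__':
    train_model(max_episodes=100)
    run_model(game_count=3)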