Example #1
def train(env,
          n_state,
          n_action,
          verbose=True,
          max_iteration=100,
          max_memory=5000,
          memory_size=1000,
          batch_size=50):
    state = env.reset()
    agent: CartPoleAgent = CartPoleAgent(n_state, n_action)
    ''' Replay memory '''
    replay_memory = []
    train_memory = []

    scores = []
    losses = []
    ''' Training loop '''
    for i in range(max_iteration):
        memory, avg_score = collect_memory(env, agent, memory_size, verbose)
        print(f'[{i+1}] avg_score = {avg_score}')
        # if avg_score >= 500:
        #     break
        ''' Push new memory '''
        replay_memory.extend(memory)
        if len(replay_memory) > max_memory:
            replay_memory = replay_memory[len(replay_memory) - max_memory:]
        ''' Random sampling '''
        train_memory = random.sample(replay_memory, memory_size)
        ''' Train the model '''
        agent.train_model_from_memory(train_memory, batch_size)
        agent.save_model(WEIGHT_PATH)

    agent.save_model(WEIGHT_PATH)
    env.close()
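Example #1 calls a collect_memory helper that is not part of the snippet. A minimal sketch of what it might look like, assuming the agent exposes get_action and the replay memory holds (state, action, reward, next_state, done) tuples:

def collect_memory(env, agent, memory_size, verbose):
    ''' Hypothetical helper: roll out full episodes until at least
        memory_size transitions are gathered, returning them together
        with the mean episode score. '''
    memory, scores = [], []
    while len(memory) < memory_size:
        state = env.reset()
        score, done = 0, False
        while not done:
            if verbose:
                env.render()
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            memory.append((state, action, reward, next_state, done))
            score += reward
            state = next_state
        scores.append(score)
    return memory, sum(scores) / len(scores)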
Example #2
def get_agent(path):
    if path[-1] != '/':
        path += '/'

    matching_files = glob.glob(path + '*.q')

    if len(matching_files) == 0:
        return

    filename = sorted(matching_files)[0]
    agent = CartPoleAgent.load(filename)

    return agent
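A possible call site, assuming models are saved under a run directory as *.q files (the directory name below is hypothetical):

agent = get_agent('data/run-1')
if agent is None:
    print('No saved *.q model found.')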
Example #3
def train(env, n_state, n_action, verbose=True, max_iteration=500):
    state = env.reset()
    agent: CartPoleAgent = CartPoleAgent(n_state, n_action)

    scores = []
    losses = []
    ''' Training loop '''
    for i in range(max_iteration):
        state = env.reset()
        score, loss, t = 0, 0, 0
        done = False
        while not done:
            if verbose:
                env.render()
            ''' Get actions based on agent NN '''
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            score += reward
            ''' Reshape reward: small bonus per surviving step, -1 on early failure '''
            reward = 0.1 if not done or score >= 500 else -1
            ''' Train the model through gradient descent '''
            loss += agent.train_model(state, action, reward, next_state, done)
            t += 1

            state = next_state
            if done:
                break
        ''' Collect score & loss data '''
        print(f'[{i+1}] score = {score} | avg_loss = {loss/t}')
        scores.append(score)
        losses.append(loss / t)
        ''' If max score occurs in last 10 times, finish the training loop '''
        if len(scores) > 9 and (sum(scores[-10:]) / 10 >= 500):
            break
        ''' Save current model & Draw the score & loss plot '''
        if i % 50 == 0 and i > 0:
            agent.save_model(WEIGHT_PATH)
            draw_plot(scores, losses)
    env.close()
    ''' Save well trained model & Draw the result score & loss plot '''
    agent.save_model(WEIGHT_PATH)
    draw_plot(scores, losses)
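Both the periodic checkpoint and the final call above use a draw_plot helper that is not shown. A minimal sketch with matplotlib, assuming it simply plots the two curves side by side:

import matplotlib.pyplot as plt

def draw_plot(scores, losses):
    ''' Hypothetical helper: plot score and average loss per episode. '''
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    ax1.plot(scores)
    ax1.set_title('Score per episode')
    ax2.plot(losses)
    ax2.set_title('Average loss per episode')
    plt.tight_layout()
    plt.show()

Example #4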
def test(env, n_state, n_action, verbose=True, n_trajectories=20):
    state = env.reset()
    agent: CartPoleAgent = CartPoleAgent(n_state, n_action)
    agent.load_model(WEIGHT_PATH)
    agent.set_eval()
    trajectories = []
    recorded = 0

    while recorded < n_trajectories:
        state = env.reset()
        score, t = 0, 0
        done = False
        trajectory = []
        while not done:
            if verbose:
                env.render()
            ''' Get actions based on agent NN '''
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            trajectory.append(list(state) + [action])

            score += reward
            t += 1
            state = next_state
            if done:
                break
        print(f'Score = {score}')
        if score >= 500:
            trajectories.append(trajectory)
            recorded += 1
            print(f'Recorded {recorded}/{n_trajectories}')

    env.close()

    for i in range(n_trajectories):
        with open(TRAJECTORY_PATH + f'CartPole_trajectory_{i}.csv', 'w') as f:
            curr_trajectory = trajectories[i]
            for step in curr_trajectory:
                str_step = map(str, step)
                f.write(','.join(str_step) + '\n')
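Each row written above holds the four CartPole state values followed by the action taken. Reading a trajectory back (the file name follows the writer above) might look like:

import csv

with open(TRAJECTORY_PATH + 'CartPole_trajectory_0.csv') as f:
    for row in csv.reader(f):
        *state, action = map(float, row)  # action is stored numerically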
Example #5
def test(env, n_state, n_action, verbose=True, max_iteration=10):
    state = env.reset()
    agent: CartPoleAgent = CartPoleAgent(n_state, n_action)
    agent.load_model(WEIGHT_PATH)
    agent.set_eval()

    for i in range(max_iteration):
        state = env.reset()
        score, t = 0, 0
        done = False
        while not done:
            if verbose:
                env.render()
            ''' Get actions based on agent NN '''
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            score += reward
            t += 1
            state = next_state
            if done:
                break
        print(f'[{i+1}] score = {score}')
    env.close()
Example #6
    def test_model_saving(self):
        observation = self.env.reset()
        action = self.agent.get_action(observation)
        observation, reward, _, _ = self.env.step(action)
        prev_observation = observation
        prev_action = action

        self.agent.update(prev_observation, prev_action, reward, observation)

        model_path = self.agent.save()
        prev_q_table = self.agent.q_table
        del self.agent

        self.agent = CartPoleAgent.load(model_path)
        assert str(self.agent.q_table) == str(prev_q_table)

        # Cleanup
        try:
            os.remove(model_path)
            os.rmdir(self.agent.model_path)
        except OSError:
            print(
                'WARNING: Could not delete test model, or its containing directory.'
            )
Example #7
import gym
from agent import CartPoleAgent

environment = gym.make('CartPole-v1')
learning_agent = CartPoleAgent(environment)


def train_agent(env, agent, n_episodes=1000):
    for ep in range(n_episodes):
        evaluate_episode(env, agent)
        # update agent
    env.close()
    return agent


def evaluate_episode(env, agent, n_steps=1000):
    observation = env.reset()
    for i in range(n_steps):
        env.render()
        action = agent.act(observation)
        observation, reward, done, info = env.step(action)
        if done:
            observation = env.reset()


trained_agent = train_agent(environment, learning_agent)
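The '# update agent' stub above is where learning would go. One way to close the loop, assuming the agent exposes the update(prev_observation, action, reward, observation) method used by the tabular agent in the later examples:

def run_episode_and_update(env, agent, n_steps=1000):
    ''' Hypothetical variant of evaluate_episode that also trains the agent. '''
    observation = env.reset()
    for _ in range(n_steps):
        env.render()
        action = agent.act(observation)
        next_observation, reward, done, info = env.step(action)
        agent.update(observation, action, reward, next_observation)  # assumed API
        observation = next_observation
        if done:
            observation = env.reset()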
Example #8
    def setUp(self):
        self.env = make('CartPole-v0')
        self.agent = CartPoleAgent(self.env.action_space,
                                   self.env.observation_space)
Example #9
class TestAgent(unittest.TestCase):
    def setUp(self):
        self.env = make('CartPole-v0')
        self.agent = CartPoleAgent(self.env.action_space,
                                   self.env.observation_space)

    def tearDown(self):
        self.env.close()
        del self.agent

    def test_bucketing(self):
        observation = self.env.reset()

        assert min(self.agent.bucketer(observation)) != -1

        for value in self.agent.bucketer(observation):
            assert 0 <= value <= self.agent.bucketer.n_buckets

    def test_valid_selection(self):
        observation = self.env.reset()

        action = self.agent.get_action(observation)
        assert self.env.action_space.contains(action)

    def test_updates_q_values(self):
        observation = self.env.reset()
        action = self.agent.get_action(observation)

        prev_q_table = str(self.agent.q_table)
        prev_observation = observation
        prev_action = action

        observation, reward, _, _ = self.env.step(action)

        self.agent.update(prev_observation, prev_action, reward, observation)
        assert str(self.agent.q_table
                   ) != prev_q_table, 'Q table unchanged:\n{}\nVS\n{}'.format(
                       self.agent.q_table, prev_q_table)

    def test_model_saving(self):
        observation = self.env.reset()
        action = self.agent.get_action(observation)
        observation, reward, _, _ = self.env.step(action)
        prev_observation = observation
        prev_action = action

        self.agent.update(prev_observation, prev_action, reward, observation)

        model_path = self.agent.save()
        prev_q_table = self.agent.q_table
        del self.agent

        self.agent = CartPoleAgent.load(model_path)
        assert str(self.agent.q_table) == str(prev_q_table)

        # Cleanup
        try:
            os.remove(model_path)
            os.rmdir(self.agent.model_path)
        except OSError:
            print(
                'WARNING: Could not delete test model, or its containing directory.'
            )
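With the setUp/tearDown spellings above, the whole suite is discovered and run by the standard library runner:

python -m unittest discover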
Example #10
    n_rows, n_cols = df.values.shape

    for row in range(n_rows):
        text.append([])

        for col in range(n_cols):
            text[row].append('{:02.4f}'.format(
                df.values[row, col]).rstrip('0').rstrip('.'))

    ax1.table(cellText=text, colLabels=df.columns, loc='center')

    ax2 = plt.subplot(122)
    img = ax2.imshow(df.drop('observation', axis=1), cmap='hot_r')
    plt.colorbar(img, ax=ax2)

    step = max(1, int(log(len(df))))  # avoid a zero step for very short tables
    ticks = [i for i in range(0, len(df), step)]
    ax2.set_yticks(ticks)
    ax2.set_yticklabels(df['observation'][df.index % step == 0])

    ax2.set_ylabel('Bucketed Observation')
    ax2.set_xlabel('Action')
    ax2.set_title('Heatmap of values in Q-Table')

    plt.tight_layout()
    plt.show()


agent = CartPoleAgent.load(args.path)
plot_qtable(agent)
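The snippet reads args.path but the argument parsing is not shown; a minimal sketch of what it implies:

import argparse

parser = argparse.ArgumentParser(description='Plot the Q-table of a saved agent.')
parser.add_argument('path', help='path to a saved .q model file')
args = parser.parse_args()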
Example #11
    dashboard = Dashboard(ema_alpha=1e-2, real_time=args.live_plot)

# Setup logger
logger = Logger(verbosity=args.log_verbosity, filename_prefix=args.model_name)
logger.log('episode_info', 'episode,timesteps')
logger.log('learning_rate', 'learning_rate')
logger.log('exploration_rate', 'exploration_rate')

# Load OpenAI Gym and agent.
env = gym.make('CartPole-v0')

model_filename = args.model_name + '.q'
checkpoint_filename_format = args.model_name + '-checkpoint-{:03d}.q'

if args.model_path:
    agent = CartPoleAgent.load(args.model_path)
    args.model_name = Path(args.model_path).name
    agent.model_path = get_run_path(prefix='data/')
else:
    agent = CartPoleAgent(env.action_space,
                          env.observation_space,
                          n_buckets=6,
                          learning_rate=1,
                          learning_rate_annealing=ExponentialDecay(k=1e-3),
                          exploration_rate=1,
                          exploration_rate_annealing=Step(k=2e-2,
                                                          step_after=100),
                          discount_factor=0.9,
                          input_mask=[0, 1, 1, 1])

start = time.time()
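ExponentialDecay and Step are the project's annealing schedules and their implementations are not shown. A purely illustrative sketch, assuming each schedule maps an elapsed episode count to a rate multiplier (the real classes may differ):

import math

class ExponentialDecay:
    ''' Hypothetical schedule: multiplier decays as exp(-k * t). '''
    def __init__(self, k):
        self.k = k

    def __call__(self, t):
        return math.exp(-self.k * t)

class Step:
    ''' Hypothetical schedule: full rate until step_after episodes,
        then a linear decrease of k per episode, floored at zero. '''
    def __init__(self, k, step_after):
        self.k = k
        self.step_after = step_after

    def __call__(self, t):
        if t < self.step_after:
            return 1.0
        return max(0.0, 1.0 - self.k * (t - self.step_after))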