# dqn_cartpole.py
import numpy as np
import pandas as pd
import gym
import json
import os
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory
from Environment import Env, status, actions
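# Environment is a local project module (not shown here); it is assumed to expose
# the custom Env class plus 'status' (an array of observations to probe) and
# 'actions' (the discrete action lookup table) used for the dumps further down.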
# ENV_NAME = 'CartPole-v0'
ENV_NAME = 'SQ-v0'
# Get the environment and extract the number of actions.
# env = gym.make(ENV_NAME)
env = Env()
np.random.seed(123)
env.seed(123)
nb_actions = env.action_space.n
# Next, we build a very simple model.
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
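# The leading 1 in the input shape matches SequentialMemory's window_length=1 below.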
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
model.summary()
# Finally, we configure and compile our agent. You can use any built-in Keras
# optimizer and any of the metrics!
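# Replay buffer over the last 50k transitions; the Boltzmann policy explores by
# sampling actions with probability proportional to exp(Q / temperature).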
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
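# target_model_update < 1 selects soft target updates: the target network tracks
# the online network with averaging factor 1e-2 per training step.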
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=20,
               target_model_update=1e-2, policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])
# Okay, now it's time to learn something! Visualization is disabled during
# training (visualize=False) because rendering slows it down quite a lot. You
# can always safely abort the training prematurely using Ctrl + C.
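# is_train is assumed to be a flag on the custom Env that switches it into
# training mode.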
env.is_train = True
# Resume from previously saved weights if a checkpoint from an earlier run exists
# (an unconditional load would crash on the very first run).
weights_file = 'dqn_{}_weights.h5f'.format(ENV_NAME)
if os.path.exists(weights_file):
    dqn.load_weights(weights_file)
dqn.fit(env, nb_steps=100000, visualize=False, verbose=2)
# After training is done, we save the final weights.
dqn.save_weights('dqn_{}_weights.h5f'.format(ENV_NAME), overwrite=True)
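# Probe the trained agent: record the action it selects for every observation
# in the 'status' grid and save the observation-to-action mapping as JSON.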
with open('dqn_action.json', 'w') as fw:
    observation = status.tolist()
    action = [float(actions[dqn.forward(np.array([obs]))]) for obs in observation]
    json.dump({'observation': observation, 'action': action}, fw)
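# Also dump the raw Q-values over the same grid of states; the reshape adds the
# window_length=1 axis that compute_batch_q_values expects, i.e.
# (batch, window_length=1, obs_dim=1).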
state_batch = status.reshape([-1, 1, 1])
q_val = pd.DataFrame(dqn.compute_batch_q_values(state_batch))
q_val.to_csv('dqn_qvalue.csv')
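# Switch the env back to evaluation mode; plot_row / plot_col are assumed to be
# layout hints for the custom Env's rendering (a 1x5 grid, one panel per episode).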
env.is_train = False
env.plot_row = 1
env.plot_col = 5
# Finally, evaluate our algorithm for 5 episodes.
dqn.test(env, nb_episodes=5, visualize=True)
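# env.plt is assumed to be the matplotlib.pyplot module used by the custom Env;
# ioff() ends interactive mode so the final figure stays open until closed.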
env.plt.ioff()
env.plt.show()