TRPO.py
from typing import List

import torch
import torch.nn as nn
from torch.distributions import Categorical
from torch.optim import Adam

# Rollout is assumed to be the trajectory container provided by utils; the code
# below unpacks it as (states, actions, rewards, next_states)
from utils import conjugate_gradient, train, get_env, bootstrap, \
    normalize, estimate_advantages, Rollout

# Build the environment
env, obs_shape, num_actions = get_env("CartPole-v0")
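# For "CartPole-v0" we expect obs_shape == (4,) (cart position/velocity, pole angle/angular
# velocity) and num_actions == 2 (push left / push right), assuming get_env() simply wraps
# gym.make() and reports the observation and action space sizes.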

# We will need an actor to take actions in the environment,
# and a critic to estimate state values and, therefore, advantages

# The actor takes a state and returns a probability for each action
actor_hidden = 32
actor = nn.Sequential(nn.Linear(obs_shape[0], actor_hidden),
                      nn.ReLU(),
                      nn.Linear(actor_hidden, num_actions),
                      nn.Softmax(dim=1))

# The critic takes a state and returns its value
critic_hidden = 32
critic = nn.Sequential(nn.Linear(obs_shape[0], critic_hidden),
                       nn.ReLU(),
                       nn.Linear(critic_hidden, 1))
critic_optimizer = Adam(critic.parameters(), lr=0.005)


# The critic will be updated so that it gives more accurate advantage estimates
def update_critic(advantages):
    loss = .5 * (advantages ** 2).mean()  # MSE
    critic_optimizer.zero_grad()
    loss.backward()
    critic_optimizer.step()
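
# Note on the loss above: with advantage = (bootstrapped return) - V(s), minimizing
# .5 * advantage ** 2 w.r.t. the critic's parameters is ordinary value regression,
# pushing V(s) toward the bootstrapped return. This relies on estimate_advantages()
# (from utils) keeping the critic's output inside the autograd graph; a hypothetical
# sketch of such an estimator (not the actual utils implementation) could be:
#   values = critic(states)                            # V(s_t), not detached
#   returns = bootstrap(rewards, critic(last_state))   # discounted returns with a bootstrap tail
#   advantages = returns - values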


# The actor decides what action to take
def get_action(state: List[float]) -> int:
    state = torch.tensor(state).float().unsqueeze(0)  # Turn the state into a batch with a single element
    probs = actor(state)
    if torch.any(torch.isnan(probs)):
        # Debugging aid: if the softmax produced NaNs, dump the actor's parameters
        for p in actor.parameters():
            print(p)
    action = Categorical(probs=probs).sample()
    return action.item()
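
# Example usage (assuming a 4-dimensional CartPole observation):
#   get_action([0.01, -0.02, 0.03, 0.04])  # -> 0 or 1, sampled from the actor's action distribution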


def flat_grad(y, x, retain_graph=False, create_graph=False):
    if create_graph:
        retain_graph = True
    g = torch.autograd.grad(y, x, retain_graph=retain_graph, create_graph=create_graph)
    g = torch.cat([t.view(-1) for t in g])
    return g
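
# flat_grad() computes d y / d x for every tensor in x and concatenates the results into one
# flat vector, so that the TRPO machinery below (dot products, conjugate gradients) can treat
# the whole parameter gradient as a single vector.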


def HVP(df, v, x):
    return flat_grad(df @ v, x, retain_graph=True)
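
# Hessian-vector product trick: for a scalar function f with flat gradient df = d f / d x,
# H @ v = d (df . v) / d x, so multiplying by the Hessian costs one extra backward pass and the
# full (num_params x num_params) Hessian is never formed. Note that this three-argument helper
# is shadowed inside update_agent() by a one-argument HVP that closes over d_kl and the actor's
# parameters; the nested version is the one actually used below.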


# Trust-region size: the maximum KL divergence allowed between the old and the new policy per update
delta = 0.01
iterations = 10


def line_search(step, criterion, alpha=0.9, max_iterations=10):
    i = 0
    while not criterion((alpha ** i) * step) and i < max_iterations:
        i += 1
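
# Backtracking line search: try the full step, then alpha * step, alpha^2 * step, ...
# until criterion() accepts a candidate or max_iterations is exhausted. criterion() itself
# applies each candidate and reverts it on rejection, so if every candidate fails the policy
# is left unchanged.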


def apply_update(grad_flattened):
    n = 0
    for p in actor.parameters():
        numel = p.numel()
        g = grad_flattened[n:n + numel].view(p.shape)
        p.data += g
        n += numel
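
# apply_update() adds a flat update vector to the actor's parameters chunk by chunk, reshaping
# each slice to the parameter's shape. The flat layout matches flat_grad(), since both iterate
# over actor.parameters() in the same order.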


def surrogate_loss(new_probabilities, old_probabilities, advantages):
    return (new_probabilities / old_probabilities * advantages).mean()
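
# This is TRPO's importance-sampled surrogate objective
#   L(theta) = E[ pi_theta(a|s) / pi_theta_old(a|s) * A(s, a) ].
# At theta = theta_old the ratio is 1 and the gradient of L equals the policy gradient;
# TRPO maximizes L subject to the constraint KL(pi_theta_old || pi_theta) <= delta.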


def kl_div(p, q):
    p = p.detach()
    return (p * (p.log() - q.log())).sum(-1).mean()
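
# kl_div() is KL(p || q) = sum_a p(a) * (log p(a) - log q(a)) for categorical distributions,
# averaged over states. p is detached, so gradients only flow through q (the new policy).
# In update_agent() it is called as kl_div(distribution, distribution): its value is zero at the
# current parameters, but its curvature (the Fisher information matrix) is exactly what the
# Hessian-vector products below need.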


# Our main update function: one TRPO policy update from a batch of rollouts
def update_agent(rollouts: List[Rollout]) -> None:
    states = torch.cat([r.states for r in rollouts], dim=0)
    actions = torch.cat([r.actions for r in rollouts], dim=0).flatten()

    advantages = [estimate_advantages(critic, states, next_states[-1], rewards)
                  for states, _, rewards, next_states in rollouts]
    advantages = normalize(torch.cat(advantages, dim=0).flatten())

    update_critic(advantages)

    distribution = actor(states)
    distribution = torch.distributions.utils.clamp_probs(distribution)
    probabilities = distribution[range(distribution.shape[0]), actions]

    # Now we have all the data we need for the algorithm.
    # We will calculate the gradient w.r.t. the new probabilities (the surrogate function),
    # so the old probabilities should be treated as a constant
    L = surrogate_loss(probabilities, probabilities.detach(), advantages)
    KL = kl_div(distribution, distribution)

    parameters = list(actor.parameters())

    g = flat_grad(L, parameters, retain_graph=True)
    d_kl = flat_grad(KL, parameters, create_graph=True)  # Create graph, because we will call backward() on it (for the HVP)
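
    # TRPO step direction: solve H x = g with conjugate gradients, where H is the Hessian of
    # the KL at the current parameters (the Fisher information matrix) and g is the gradient
    # of the surrogate. HVP(v) returns H @ v by differentiating (d_kl . v), so H is never
    # built explicitly.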
    def HVP(v):
        return flat_grad(d_kl @ v, parameters, retain_graph=True)

    search_dir = conjugate_gradient(HVP, g)
    max_length = torch.sqrt(2 * delta / (search_dir @ HVP(search_dir)))
    max_step = max_length * search_dir
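
    # Maximal step length along the search direction s: the quadratic model of the KL is
    # 0.5 * (beta * s)^T H (beta * s), and setting it equal to delta gives
    # beta = sqrt(2 * delta / (s^T H s)). The line search below then backtracks from this
    # maximal step, accepting the first candidate that actually improves the surrogate while
    # keeping the true KL within delta.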

    def criterion(step):
        apply_update(step)
        with torch.no_grad():
            distribution_new = actor(states)
            distribution_new = torch.distributions.utils.clamp_probs(distribution_new)
            probabilities_new = distribution_new[range(distribution_new.shape[0]), actions]

            L_new = surrogate_loss(probabilities_new, probabilities, advantages)
            KL_new = kl_div(distribution, distribution_new)

        L_improvement = L_new - L
        if L_improvement > 0 and KL_new <= delta:
            return True

        apply_update(-step)
        return False

    line_search(max_step, criterion, max_iterations=10)


# Train using our get_action() and update_agent() functions
train(env, get_action, update_agent, num_rollouts=10, render_frequency=None, print_frequency=10,
      plot_frequency=None, epochs=1000)