def train(self, transitions, is_weights=None):
    states, actions, rewards, dones, next_states = transitions
    dones_tensor = to_tensor(dones)
    rewards_tensor = to_tensor(rewards)
    states_tensor = to_tensor(states)
    next_states_tensor = to_tensor(next_states)
    batch_indices = np.arange(states.shape[0])

    # Q-values of the actions that were actually taken.
    q_vals = self.q_net(states_tensor)
    selected_q_vals = q_vals[batch_indices, actions]

    # Double DQN target: the online network chooses the next action,
    # the target network evaluates it.
    next_q_vals = self.q_net(next_states_tensor)
    next_q_target_vals = self.q_target(next_states_tensor)
    selected_actions = next_q_vals.argmax(dim=1)
    target_q_vals = rewards_tensor + self.gamma * (1 - dones_tensor) * \
        next_q_target_vals[batch_indices, selected_actions]

    # Detach the target so gradients only flow into the online network.
    td_error = selected_q_vals - target_q_vals.detach()
    if is_weights is not None:
        # Importance-sampling weights (e.g. from prioritized replay).
        is_weights_tensor = to_tensor(is_weights)
        loss = torch.mean(is_weights_tensor * td_error**2)
    else:
        loss = torch.mean(td_error**2)

    self.q_net.zero_grad()
    loss.backward()
    self.optimizer.step()
    return td_error
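These listings assume the standard imports and a small to_tensor conversion helper that is not shown here. A minimal sketch, assuming the helper only needs to turn NumPy data into float32 tensors:

import random

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical

def to_tensor(x):
    # Hypothetical helper (not shown in the original listings): convert NumPy
    # arrays, lists, or scalars to float32 torch tensors.
    return torch.as_tensor(np.asarray(x), dtype=torch.float32)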
def train(self, transitions):
    states, actions, rewards, dones, next_states = transitions
    dones_tensor = to_tensor(dones)
    actions_tensor = to_tensor(actions)
    rewards_tensor = to_tensor(rewards)
    states_tensor = to_tensor(states)
    next_states_tensor = to_tensor(next_states)

    # Critic update: regress Q(s, a) toward the bootstrapped target built from
    # the target policy and target critic; the target is detached so gradients
    # only flow into the online critic.
    q_vals = self.q_net(states_tensor, actions_tensor)
    next_actions = self.policy_target(next_states_tensor)
    next_q_target_vals = self.q_target(next_states_tensor, next_actions)
    target_q_vals = rewards_tensor + self.gamma * (1 - dones_tensor) * next_q_target_vals
    q_loss = nn.MSELoss()(q_vals, target_q_vals.detach())
    self.q_net.zero_grad()
    q_loss.backward()
    self.q_optimizer.step()

    # Actor update: ascend the critic's value of the policy's own actions.
    policy_q_vals = self.q_net(states_tensor, self.policy(states_tensor))
    policy_loss = -torch.mean(policy_q_vals)
    self.policy.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()
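The critic and actor updates above read from self.q_target and self.policy_target, but the listing does not show how those target networks are refreshed. A common choice is a Polyak (soft) update; the helper below is a sketch under that assumption, not the original code:

def soft_update(target_net, source_net, tau=0.005):
    # Blend each target parameter toward the corresponding online parameter.
    # tau is an assumed hyperparameter; typical values are small (0.001-0.01).
    with torch.no_grad():
        for target_param, param in zip(target_net.parameters(), source_net.parameters()):
            target_param.mul_(1.0 - tau).add_(tau * param)

If this scheme is used, it would be called after each optimizer step, e.g. soft_update(self.q_target, self.q_net) and soft_update(self.policy_target, self.policy).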
def select_action(self, obs, train=True):
    obs_tensor = to_tensor(obs)
    action = self.policy(obs_tensor).detach().numpy()
    if train:
        # Add Gaussian exploration noise, then clip to the valid action range
        # (note: this reads the module-level `env` for the action bounds).
        action = np.clip(
            action + np.random.randn(action.size) * self.action_noise,
            env.action_space.low,
            env.action_space.high,
        )
    return action
def step(self, action, threshold=0.5):
    curr_state_tensor = to_tensor(self.curr_state.reshape(1, -1))
    action_tensor = torch.Tensor([action]).float()
    next_state, r, d_sig, _ = self.forward(curr_state_tensor, action_tensor)
    # The model predicts a termination probability; threshold it to get a
    # boolean done flag and return a gym-style (obs, reward, done, info) tuple.
    return (
        next_state[0].detach().numpy(),
        r[0, 0].detach().item(),
        d_sig[0, 0].item() > threshold,
        {},
    )
def train_model(env, model, policy, memory, eps, train_eps, batch_size, lr):
    # Collect `eps` environment steps with the given policy and store the
    # transitions (and the initial states) for model training.
    obs = env.reset()
    model.init_states.store(obs)
    for i in range(eps):
        action = policy.get_action(obs)
        next_obs, reward, done, _ = env.step(action)
        memory.store((obs, action, reward, done, next_obs))
        if done:
            obs = env.reset()
            model.init_states.store(obs)
        else:
            obs = next_obs

    # Fit the dynamics model on the collected data for `train_eps` epochs.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    total_losses = []
    for i in range(train_eps):
        indices = np.arange(len(memory))
        np.random.shuffle(indices)
        for batch_start in range(0, len(memory), batch_size):
            batch = memory.get_data(indices[batch_start:batch_start + batch_size])
            states, actions, rewards, dones, next_states = batch
            states_tensor = to_tensor(states)
            actions_tensor = to_tensor(actions)
            rewards_tensor = to_tensor(rewards)
            dones_tensor = to_tensor(dones)
            next_states_tensor = to_tensor(next_states)

            pred_ns, pred_r, _, pred_d_logits = model(states_tensor, actions_tensor)
            ns_loss = nn.MSELoss()(pred_ns, next_states_tensor)
            r_loss = nn.MSELoss()(pred_r.squeeze(dim=1), rewards_tensor)
            # Done flags are rare, so the termination loss is upweighted.
            d_loss = nn.BCEWithLogitsLoss()(pred_d_logits.squeeze(dim=1), dones_tensor)
            total_loss = ns_loss + r_loss + 100 * d_loss

            model.zero_grad()
            total_loss.backward()
            optimizer.step()
            # Store the scalar loss rather than the tensor so the graph is freed.
            total_losses.append(total_loss.item())
    return total_losses
def select_action(self, obs, train=True):
    # Epsilon-greedy: explore with probability eps during training,
    # otherwise act greedily with respect to the Q-network.
    if train and random.random() < self.eps:
        return random.randrange(self.n_actions)
    obs_tensor = to_tensor(obs)
    return torch.argmax(self.q_net(obs_tensor)).numpy()
def __init__(self, state_size, action_size, action_range, hidden_layers):
    super().__init__()
    # state_size and action_size are shape tuples (e.g. env.observation_space.shape
    # and env.action_space.shape); action_range holds the valid action bounds.
    self.action_range = to_tensor(action_range)
    self.net = MLP((state_size[0], *hidden_layers, action_size[0]))
def get_action(self, obs):
    # Sample a discrete action from the categorical distribution over the logits.
    logits = self.forward(to_tensor(obs))
    return Categorical(logits=logits).sample().numpy()
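For context, a hypothetical rollout using this sampling policy; `env` is assumed to be a Gym-style environment (as in train_model above) and `policy` an instance of the network that defines get_action, neither of which is constructed in these listings:

# Hypothetical usage sketch; `env` and `policy` are assumptions, not original code.
obs = env.reset()
done = False
episode_reward = 0.0
while not done:
    action = policy.get_action(obs)
    obs, reward, done, _ = env.step(action)
    episode_reward += reward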