import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

# Network, NoisyNetwork, and PrioritizedReplayBuffer are defined elsewhere in
# this project. The replay buffer is used with the OpenAI Baselines interface:
# add(s, a, r, s2, done), sample(batch_size, beta) -> (states, actions, rewards,
# next_states, dones, weights, indexes), and update_priorities(indexes, priorities).
class NoisyDQN:  # class header missing in the source; name assumed from NoisyNetwork
    def __init__(self, observation_space, action_space, lr=1e-3, gamma=0.99, tau=0.01):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.tau = tau
        self.memory = PrioritizedReplayBuffer(10000, 0.6)
        self.beta = 0.6
        self.update_count = 0
        self.dqn = NoisyNetwork(observation_space.shape[0], action_space.n).to(self.device)
        self.dqn_target = NoisyNetwork(observation_space.shape[0], action_space.n).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()
        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)
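
# NoisyNetwork is defined elsewhere in the project. For reference, here is a
# minimal sketch of the factorised-Gaussian NoisyLinear layer (Fortunato et
# al., 2017) that such a network is typically built from. The layer, its names,
# and std_init are illustrative assumptions, not the project's implementation.
class NoisyLinear(nn.Module):
    """Linear layer whose weights are perturbed by learnable, factorised noise."""

    def __init__(self, in_features, out_features, std_init=0.5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.std_init = std_init
        # Learnable means and standard deviations of the weight/bias noise.
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        # Noise samples; buffers so they follow the module across devices.
        self.register_buffer("weight_epsilon", torch.empty(out_features, in_features))
        self.register_buffer("bias_epsilon", torch.empty(out_features))
        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        mu_range = 1.0 / math.sqrt(self.in_features)
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.in_features))
        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.out_features))

    @staticmethod
    def _scale_noise(size):
        x = torch.randn(size)
        return x.sign().mul_(x.abs().sqrt_())

    def reset_noise(self):
        # Factorised noise: one vector per input unit, one per output unit.
        epsilon_in = self._scale_noise(self.in_features)
        epsilon_out = self._scale_noise(self.out_features)
        self.weight_epsilon.copy_(epsilon_out.outer(epsilon_in))
        self.bias_epsilon.copy_(epsilon_out)

    def forward(self, x):
        if self.training:
            weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
            bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        else:
            weight, bias = self.weight_mu, self.bias_mu
        return F.linear(x, weight, bias)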
class DuelingDDQN:
    def __init__(self, observation_space, action_space, lr=1e-3, gamma=0.99, tau=0.01):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.tau = tau
        self.memory = PrioritizedReplayBuffer(100000, 0.6)
        self.action_space = action_space
        self.beta = 0.6
        self.epsilon = 0.7
        self.epsilon_decay = 0.995
        self.min_epsilon = 0.01
        self.update_count = 0
        self.dqn = Network(observation_space.shape[0], action_space.n).to(self.device)
        self.dqn_target = Network(observation_space.shape[0], action_space.n).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()
        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)

    def act(self, state):
        # Epsilon-greedy over a batch of vectorized-environment observations.
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.min_epsilon)
        if np.random.random() < self.epsilon:
            return [self.action_space.sample() for _ in range(len(state))]
        state = torch.FloatTensor(state).to(self.device)
        action = self.dqn(state).argmax(dim=-1)
        return action.cpu().detach().numpy()

    def remember(self, states, actions, rewards, new_states, dones):
        for i in range(len(states)):
            self.memory.add(states[i], actions[i], rewards[i], new_states[i], dones[i])

    def train(self, batch_size=32, epochs=1):
        if len(self.memory._storage) < 1000:
            return
        for _ in range(epochs):
            self.update_count += 1
            # Anneal the importance-sampling exponent toward 1, never past it.
            self.beta = min(1.0, self.beta + self.update_count / 100000 * (1.0 - self.beta))
            (states, actions, rewards, next_states, dones, weights,
             batch_indexes) = self.memory.sample(batch_size, self.beta)
            states = torch.FloatTensor(states).to(self.device)
            actions = torch.LongTensor(actions).unsqueeze(-1).to(self.device)
            rewards = torch.FloatTensor(rewards).unsqueeze(-1).to(self.device)
            next_states = torch.FloatTensor(next_states).to(self.device)
            dones = torch.FloatTensor(dones).unsqueeze(-1).to(self.device)
            weights = torch.FloatTensor(weights).unsqueeze(-1).to(self.device)
            q = self.dqn(states).gather(-1, actions)
            # Double DQN: the online net selects the action, the target net evaluates it.
            a2 = self.dqn(next_states).argmax(dim=-1, keepdim=True)
            q2 = self.dqn_target(next_states).gather(-1, a2).detach()
            target = rewards + (1 - dones) * self.gamma * q2
            td_error = F.mse_loss(q, target, reduction="none")
            loss = torch.mean(td_error * weights)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.update_target()
            priorities = td_error.detach().cpu().numpy().flatten() + 1e-6
            self.memory.update_priorities(batch_indexes, priorities)

    def update_target(self):
        # Polyak averaging: target <- (1 - tau) * target + tau * online.
        with torch.no_grad():
            for target_param, param in zip(self.dqn_target.parameters(), self.dqn.parameters()):
                target_param.data.mul_(1 - self.tau)
                target_param.data.add_(param.data, alpha=self.tau)
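
# The `Network` DuelingDDQN constructs is defined elsewhere. Below is a minimal
# dueling-head sketch consistent with how it is used here (constructor taking
# (obs_dim, n_actions), forward returning Q-values); the class name and hidden
# sizes are illustrative assumptions.
class DuelingNetwork(nn.Module):
    def __init__(self, obs_dim, n_actions, hidden=128):
        super().__init__()
        self.feature = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.value = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(),
                                   nn.Linear(hidden, 1))
        self.advantage = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(),
                                       nn.Linear(hidden, n_actions))

    def forward(self, x):
        h = self.feature(x)
        value = self.value(h)
        advantage = self.advantage(h)
        # Subtract the mean advantage so the V and A streams are identifiable.
        return value + advantage - advantage.mean(dim=-1, keepdim=True)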
class CategoricalDQN:
    def __init__(self, observation_space, action_space, lr=1e-3, gamma=0.99, tau=0.01):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.tau = tau
        self.beta = 0.6
        self.memory = PrioritizedReplayBuffer(100000, 0.5)
        self.action_space = action_space
        self.epsilon = 0.7
        self.epsilon_decay = 0.995
        self.min_epsilon = 0.01
        # Support of the return distribution: 51 atoms spanning [v_min, v_max].
        self.v_min = 0.
        self.v_max = 500.
        self.atom_size = 51
        self.support = torch.linspace(self.v_min, self.v_max, self.atom_size).to(self.device)
        self.update_count = 0
        self.dqn = Network(observation_space.shape[0], action_space.n,
                           self.atom_size, self.support).to(self.device)
        self.dqn_target = Network(observation_space.shape[0], action_space.n,
                                  self.atom_size, self.support).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()
        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)

    def act(self, state):
        # Epsilon-greedy over a batch of vectorized-environment observations.
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.min_epsilon)
        if np.random.random() < self.epsilon:
            return [self.action_space.sample() for _ in range(len(state))]
        state = torch.FloatTensor(state).to(self.device)
        action = self.dqn(state).argmax(dim=-1)
        return action.cpu().detach().numpy()

    def remember(self, states, actions, rewards, new_states, dones):
        for i in range(len(states)):
            self.memory.add(states[i], actions[i], rewards[i], new_states[i], dones[i])

    def train(self, batch_size=32, epochs=1):
        if len(self.memory._storage) < 1000:
            return
        for _ in range(epochs):
            self.update_count += 1
            # Anneal the importance-sampling exponent toward 1, never past it.
            self.beta = min(1.0, self.beta + self.update_count / 100000 * (1.0 - self.beta))
            (states, actions, rewards, next_states, dones, weights,
             batch_indexes) = self.memory.sample(batch_size, self.beta)
            states = torch.FloatTensor(states).to(self.device)
            actions = torch.LongTensor(actions).to(self.device)
            rewards = torch.FloatTensor(rewards).unsqueeze(-1).to(self.device)
            next_states = torch.FloatTensor(next_states).to(self.device)
            dones = torch.FloatTensor(dones).unsqueeze(-1).to(self.device)
            # Keep weights 1-D: the per-sample loss below is also 1-D, and an
            # (N, 1) * (N,) product would broadcast to an (N, N) matrix.
            weights = torch.FloatTensor(weights).to(self.device)

            delta_z = float(self.v_max - self.v_min) / (self.atom_size - 1)
            with torch.no_grad():
                next_action = self.dqn_target(next_states).argmax(dim=1)
                next_dist = self.dqn_target.dist(next_states)
                next_dist = next_dist[range(batch_size), next_action]
                # Project the Bellman-updated support back onto the fixed atoms.
                t_z = rewards + (1 - dones) * self.gamma * self.support
                t_z = t_z.clamp(min=self.v_min, max=self.v_max)
                b = (t_z - self.v_min) / delta_z
                l = b.floor().long()
                u = b.ceil().long()
                # When b lands exactly on an atom, l == u and both terms below
                # vanish; shift one bound so no probability mass is dropped.
                l[(u > 0) * (l == u)] -= 1
                u[(l < (self.atom_size - 1)) * (l == u)] += 1
                offset = torch.linspace(0, (batch_size - 1) * self.atom_size, batch_size)
                offset = offset.long().unsqueeze(1).expand(batch_size, self.atom_size).to(self.device)
                proj_dist = torch.zeros(next_dist.size(), device=self.device)
                proj_dist.view(-1).index_add_(0, (l + offset).view(-1),
                                              (next_dist * (u.float() - b)).view(-1))
                proj_dist.view(-1).index_add_(0, (u + offset).view(-1),
                                              (next_dist * (b - l.float())).view(-1))

            dist = self.dqn.dist(states)
            log_p = torch.log(dist[range(batch_size), actions])
            # Per-sample cross-entropy between the projected target and the prediction.
            td_error = -(proj_dist * log_p).sum(1)
            loss = torch.mean(td_error * weights)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.update_target()

            priorities = td_error.detach().cpu().numpy() + 1e-6
            self.memory.update_priorities(batch_indexes, priorities)

    def update_target(self):
        # Polyak averaging: target <- tau * online + (1 - tau) * target.
        for target_param, param in zip(self.dqn_target.parameters(), self.dqn.parameters()):
            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))
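
# The categorical `Network` that CategoricalDQN (and Rainbow) construct is
# defined elsewhere. Below is a minimal sketch consistent with its call sites:
# constructor (obs_dim, n_actions, atom_size, support), dist() returning a
# per-action probability distribution over atoms, forward() returning expected
# Q-values. The class name and hidden sizes are illustrative assumptions, and
# Rainbow's version additionally exposes reset_noise() via noisy layers, which
# this sketch omits.
class CategoricalNetwork(nn.Module):
    def __init__(self, obs_dim, n_actions, atom_size, support, hidden=128):
        super().__init__()
        self.n_actions = n_actions
        self.atom_size = atom_size
        self.support = support
        self.layers = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions * atom_size),
        )

    def dist(self, x):
        # Softmax over atoms gives one probability distribution per action;
        # the clamp keeps log(dist) finite in the cross-entropy loss.
        q_atoms = self.layers(x).view(-1, self.n_actions, self.atom_size)
        return F.softmax(q_atoms, dim=-1).clamp(min=1e-3)

    def forward(self, x):
        # Q(s, a) is the expectation of the support under the predicted distribution.
        return torch.sum(self.dist(x) * self.support, dim=2)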
class Rainbow:
    def __init__(self, observation_space, action_space, lr=7e-4, gamma=0.99,
                 tau=0.01, n_step=3, n_envs=1):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.tau = tau
        self.beta = 0.6
        self.memory = PrioritizedReplayBuffer(10000, 0.5)
        self.n_step = n_step
        self.action_space = action_space
        # Support of the return distribution: 51 atoms spanning [v_min, v_max].
        self.v_min = 0.
        self.v_max = 500.
        self.atom_size = 51
        self.support = torch.linspace(self.v_min, self.v_max, self.atom_size).to(self.device)
        self.update_count = 0
        self.dqn = Network(observation_space.shape[0], action_space.n,
                           self.atom_size, self.support).to(self.device)
        self.dqn_target = Network(observation_space.shape[0], action_space.n,
                                  self.atom_size, self.support).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()
        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)

    def act(self, state):
        # Exploration comes from the noisy layers, so action selection is greedy.
        state = torch.FloatTensor(state).to(self.device)
        action = self.dqn(state).argmax(dim=-1)
        return action.cpu().detach().numpy()

    def remember(self, states, actions, rewards, new_states, dones):
        for i in range(len(states)):
            self.memory.add(states[i], actions[i], rewards[i], new_states[i], dones[i])

    def train(self, batch_size=32):
        if len(self.memory._storage) < 500:
            return
        self.update_count += 1
        # Anneal the importance-sampling exponent toward 1, never past it.
        self.beta = min(1.0, self.beta + self.update_count / 100000 * (1.0 - self.beta))
        (states, actions, rewards, next_states, dones, weights,
         batch_indexes) = self.memory.sample(batch_size, self.beta)
        # Keep weights 1-D: calculate_loss returns one cross-entropy term per sample.
        weights = torch.FloatTensor(weights).to(self.device)
        td_error = self.calculate_loss(states, actions, rewards, next_states,
                                       dones, self.gamma)  # ** self.n_step)
        # n-step returns (disabled; requires a parallel n-step buffer):
        # gamma = self.gamma ** self.n_step
        # (states, actions, rewards, next_states, dones) = \
        #     self.memory_n.sample_batch_from_idxs(batch_indexes)
        # n_loss = self.calculate_loss(states, actions, rewards, next_states, dones, gamma)
        # td_error += n_loss
        loss = torch.mean(td_error * weights)
        self.optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(self.dqn.parameters(), 10.0)
        self.optimizer.step()
        self.update_target()
        priorities = td_error.detach().cpu().numpy() + 1e-6
        self.memory.update_priorities(batch_indexes, priorities)
        # Resample the noise so the next forward passes explore differently.
        self.dqn.reset_noise()
        self.dqn_target.reset_noise()

    def calculate_loss(self, states, actions, rewards, next_states, dones, gamma):
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(-1).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).unsqueeze(-1).to(self.device)
        delta_z = float(self.v_max - self.v_min) / (self.atom_size - 1)
        with torch.no_grad():
            next_action = self.dqn_target(next_states).argmax(dim=1)
            next_dist = self.dqn_target.dist(next_states)
            next_dist = next_dist[range(len(states)), next_action]
            # Project the Bellman-updated support back onto the fixed atoms.
            t_z = rewards + (1 - dones) * gamma * self.support
            t_z = t_z.clamp(min=self.v_min, max=self.v_max)
            b = (t_z - self.v_min) / delta_z
            l = b.floor().long()
            u = b.ceil().long()
            # When b lands exactly on an atom, l == u and both terms below
            # vanish; shift one bound so no probability mass is dropped.
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atom_size - 1)) * (l == u)] += 1
            offset = torch.linspace(0, (len(states) - 1) * self.atom_size, len(states))
            offset = offset.long().unsqueeze(1).expand(len(states), self.atom_size).to(self.device)
            proj_dist = torch.zeros(next_dist.size(), device=self.device)
            proj_dist.view(-1).index_add_(0, (l + offset).view(-1),
                                          (next_dist * (u.float() - b)).view(-1))
            proj_dist.view(-1).index_add_(0, (u + offset).view(-1),
                                          (next_dist * (b - l.float())).view(-1))
        dist = self.dqn.dist(states)
        log_p = torch.log(dist[range(len(states)), actions])
        # Per-sample cross-entropy between the projected target and the prediction.
        td_error = -(proj_dist * log_p).sum(1)
        return td_error

    def update_target(self):
        # Polyak averaging: target <- (1 - tau) * target + tau * online.
        with torch.no_grad():
            for target_param, param in zip(self.dqn_target.parameters(), self.dqn.parameters()):
                target_param.data.mul_(1 - self.tau)
                target_param.data.add_(param.data, alpha=self.tau)

    def hard_update_target(self):
        self.dqn_target.load_state_dict(self.dqn.state_dict())

    def save_model(self, path):
        torch.save(self.dqn.state_dict(), path)

    def load_model(self, path):
        self.dqn.load_state_dict(torch.load(path))
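
# The disabled block in Rainbow.train() references a parallel n-step buffer
# (self.memory_n) with a sample_batch_from_idxs method. That buffer is not in
# this file; the sketch below is a hypothetical reconstruction from the call
# site, showing the usual deque-based accumulation of n-step returns. For the
# shared batch_indexes to line up, it would have to be filled in lockstep with
# the 1-step buffer and use the same capacity.
from collections import deque

class NStepBuffer:
    """Stores (s, a, R_n, s_{t+n}, done) transitions with a discounted n-step return."""

    def __init__(self, size, n_step=3, gamma=0.99):
        self.n_step = n_step
        self.gamma = gamma
        self.n_step_queue = deque(maxlen=n_step)
        self._storage = deque(maxlen=size)

    def add(self, state, action, reward, next_state, done):
        self.n_step_queue.append((state, action, reward, next_state, done))
        if len(self.n_step_queue) < self.n_step:
            return
        # Fold the queued rewards into a single discounted n-step return,
        # truncating at the first terminal transition.
        reward_n, next_state_n, done_n = self.n_step_queue[-1][2:]
        for (_, _, r, n_s, d) in reversed(list(self.n_step_queue)[:-1]):
            reward_n = r + self.gamma * reward_n * (1 - d)
            if d:
                next_state_n, done_n = n_s, d
        state_0, action_0 = self.n_step_queue[0][:2]
        self._storage.append((state_0, action_0, reward_n, next_state_n, done_n))

    def sample_batch_from_idxs(self, idxes):
        batch = [self._storage[i] for i in idxes]
        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
        return states, actions, rewards, next_states, dones

# Example usage (illustrative, not part of the project): a short smoke-test of
# Rainbow on CartPole-v1 with gymnasium's vector API, which supplies the
# batched observations that act()/remember() expect. The environment id, step
# budget, and batch size are arbitrary, and gymnasium is an assumed dependency.
if __name__ == "__main__":
    import gymnasium as gym

    envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")] * 4)
    agent = Rainbow(envs.single_observation_space, envs.single_action_space)
    states, _ = envs.reset(seed=0)
    for _ in range(10000):
        actions = agent.act(states)
        next_states, rewards, terminated, truncated, _ = envs.step(np.asarray(actions))
        dones = np.logical_or(terminated, truncated).astype(np.float32)
        agent.remember(states, actions, rewards, next_states, dones)
        agent.train(batch_size=32)
        states = next_states  # vector envs auto-reset finished sub-environments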