Example #1
0
    def __init__(self,
                 observation_space,
                 action_space,
                 lr=1e-3,
                 gamma=0.99,
                 tau=0.01):
        """Set up replay memory, exploration schedule, value support and networks.

        Args:
            observation_space: space whose ``shape[0]`` is the network input size.
            action_space: discrete space; ``n`` is the number of actions and
                ``sample()`` is kept for random exploration.
            lr: Adam learning rate for the online network.
            gamma: discount factor.
            tau: soft target-update coefficient.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Learning hyper-parameters and prioritized replay memory
        # (arguments presumably (capacity, alpha) -- confirm against the buffer).
        self.gamma = gamma
        self.tau = tau
        self.beta = 0.6
        self.memory = PrioritizedReplayBuffer(100000, 0.5)
        self.action_space = action_space

        # Epsilon-greedy exploration schedule.
        self.epsilon, self.epsilon_decay, self.min_epsilon = 0.7, 0.995, 0.01

        # Fixed support: atom_size evenly spaced atoms over [v_min, v_max].
        self.v_min, self.v_max, self.atom_size = 0., 500., 51
        self.support = torch.linspace(
            self.v_min, self.v_max, self.atom_size).to(self.device)

        self.update_count = 0
        n_inputs = observation_space.shape[0]
        n_actions = action_space.n
        self.dqn = Network(n_inputs, n_actions, self.atom_size,
                           self.support).to(self.device)
        self.dqn_target = Network(n_inputs, n_actions, self.atom_size,
                                  self.support).to(self.device)
        # The target starts as an exact copy of the online net, in eval mode.
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)
Example #2
0
    def __init__(self, observation_space, action_space, lr=1e-3, gamma=0.99, tau=0.01):
        """Set up prioritized replay memory and the online/target noisy networks.

        Args:
            observation_space: space whose ``shape[0]`` is the network input size.
            action_space: discrete space; ``n`` is the number of actions.
            lr: Adam learning rate for the online network.
            gamma: discount factor.
            tau: soft target-update coefficient.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.gamma, self.tau = gamma, tau
        # Replay arguments presumably (capacity, alpha) -- confirm against the buffer.
        self.memory = PrioritizedReplayBuffer(10000, 0.6)
        self.beta = 0.6
        self.update_count = 0

        in_dim, out_dim = observation_space.shape[0], action_space.n
        self.dqn = NoisyNetwork(in_dim, out_dim).to(self.device)
        self.dqn_target = NoisyNetwork(in_dim, out_dim).to(self.device)
        # Target mirrors the online weights and is kept in eval mode.
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)
Example #3
0
class DuelingDDQN:
    """Double-DQN agent with prioritized replay, epsilon-greedy exploration and
    soft target updates.

    Any dueling architecture lives inside ``Network`` (not shown here).
    """

    def __init__(self, observation_space, action_space, lr=1e-3, gamma=0.99, tau=0.01):
        """Create replay memory, exploration schedule and online/target networks.

        Args:
            observation_space: space whose ``shape[0]`` is the network input size.
            action_space: discrete space; ``n`` is the number of actions and
                ``sample()`` is used for random exploration.
            lr: Adam learning rate.
            gamma: discount factor.
            tau: soft target-update coefficient.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.gamma = gamma
        self.tau = tau
        self.memory = PrioritizedReplayBuffer(100000, 0.6)
        self.action_space = action_space

        self.beta = 0.6
        self.epsilon = 0.7
        self.epsilon_decay = 0.995
        self.min_epsilon = 0.01

        self.update_count = 0
        self.dqn = Network(observation_space.shape[0], action_space.n).to(self.device)
        self.dqn_target = Network(observation_space.shape[0], action_space.n).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)

    def act(self, state):
        """Pick one action per row of ``state`` (epsilon-greedy).

        Epsilon decays on every call and is floored at ``min_epsilon``. With
        probability epsilon the whole batch receives random actions (returned as
        a list). Otherwise the greedy actions are returned as a numpy array.
        """
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.min_epsilon)

        if np.random.random() < self.epsilon:
            return [self.action_space.sample() for _ in range(len(state))]

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device)
            action = self.dqn.forward(state).argmax(dim=-1)
        return action.cpu().numpy()

    def remember(self, states, actions, rewards, new_states, dones):
        """Store each transition of a batch in the replay memory."""
        for i in range(len(states)):
            self.memory.add(states[i], actions[i], rewards[i], new_states[i], dones[i])

    def train(self, batch_size=32, epochs=1):
        """Run ``epochs`` Double-DQN updates on prioritized samples.

        Does nothing until the memory holds at least 1000 transitions.
        """
        if 1000 > len(self.memory._storage):
            return

        for _ in range(epochs):
            self.update_count += 1

            # Anneal the importance-sampling exponent towards 1, never above it.
            self.beta = min(1.0, self.beta + self.update_count / 100000 * (1.0 - self.beta))

            (states, actions, rewards, next_states, dones, weights,
             batch_indexes) = self.memory.sample(batch_size, self.beta)

            states = torch.FloatTensor(states).to(self.device)
            actions = torch.FloatTensor(actions).unsqueeze(-1).long().to(self.device)
            rewards = torch.FloatTensor(rewards).unsqueeze(-1).to(self.device)
            next_states = torch.FloatTensor(next_states).to(self.device)
            dones = torch.FloatTensor(dones).unsqueeze(-1).to(self.device)
            weights = torch.FloatTensor(weights).unsqueeze(-1).to(self.device)

            q = self.dqn.forward(states).gather(-1, actions)
            with torch.no_grad():
                # Double DQN: the online net selects the next action and the
                # target net evaluates it.
                a2 = self.dqn.forward(next_states).argmax(dim=-1, keepdim=True)
                q2 = self.dqn_target.forward(next_states).gather(-1, a2)
                target = rewards + (1 - dones) * self.gamma * q2

            td_error = F.mse_loss(q, target, reduction="none")
            loss = torch.mean(td_error * weights)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.update_target()

            # PER priority is |TD error| + eps, one scalar per sampled transition.
            priorities = (q - target).abs().detach().squeeze(-1).cpu().numpy() + 1e-6
            self.memory.update_priorities(batch_indexes, priorities)

    def update_target(self):
        """Soft update: target <- (1 - tau) * target + tau * online."""
        with torch.no_grad():
            for target_param, param in zip(self.dqn_target.parameters(), self.dqn.parameters()):
                target_param.data.mul_(1 - self.tau)
                torch.add(target_param.data, param.data, alpha=self.tau, out=target_param.data)
Example #4
0
class CategoricalDQN:
    """Categorical (C51-style) DQN agent with prioritized replay,
    epsilon-greedy exploration and soft target updates.

    Each action's return distribution is represented on a fixed support of
    ``atom_size`` atoms evenly spaced over ``[v_min, v_max]``.
    """

    def __init__(self,
                 observation_space,
                 action_space,
                 lr=1e-3,
                 gamma=0.99,
                 tau=0.01):
        """Create replay memory, exploration schedule, support and networks.

        Args:
            observation_space: space whose ``shape[0]`` is the network input size.
            action_space: discrete space; ``n`` is the number of actions and
                ``sample()`` is used for random exploration.
            lr: Adam learning rate.
            gamma: discount factor.
            tau: soft target-update coefficient.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.gamma = gamma
        self.tau = tau
        self.beta = 0.6
        self.memory = PrioritizedReplayBuffer(100000, 0.5)
        self.action_space = action_space

        self.epsilon = 0.7
        self.epsilon_decay = 0.995
        self.min_epsilon = 0.01

        self.v_min = 0.
        self.v_max = 500.
        self.atom_size = 51
        self.support = torch.linspace(self.v_min, self.v_max,
                                      self.atom_size).to(self.device)

        self.update_count = 0
        self.dqn = Network(observation_space.shape[0], action_space.n,
                           self.atom_size, self.support).to(self.device)
        self.dqn_target = Network(observation_space.shape[0], action_space.n,
                                  self.atom_size, self.support).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)

    def act(self, state):
        """Pick one action per row of ``state`` (epsilon-greedy).

        Epsilon decays on every call and is floored at ``min_epsilon``. Random
        actions are returned as a list; greedy ones as a numpy array.
        """
        self.epsilon *= self.epsilon_decay
        self.epsilon = max(self.epsilon, self.min_epsilon)

        if np.random.random() < self.epsilon:
            return [self.action_space.sample() for _ in range(len(state))]

        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device)
            action = self.dqn.forward(state).argmax(dim=-1)
        return action.cpu().numpy()

    def remember(self, states, actions, rewards, new_states, dones):
        """Store each transition of a batch in the replay memory."""
        for i in range(len(states)):
            self.memory.add(states[i], actions[i], rewards[i], new_states[i],
                            dones[i])

    def _project_distribution(self, next_dist, rewards, dones, gamma):
        """Project the Bellman-shifted distribution back onto ``self.support``.

        Args:
            next_dist: (batch, atom_size) probabilities for the chosen next action.
            rewards: (batch, 1) rewards.
            dones: (batch, 1) terminal flags (1.0 = terminal).
            gamma: discount applied to the support.

        Returns:
            A (batch, atom_size) tensor whose rows keep the mass of ``next_dist``.
        """
        batch_size = next_dist.size(0)
        delta_z = float(self.v_max - self.v_min) / (self.atom_size - 1)

        t_z = rewards + (1 - dones) * gamma * self.support
        t_z = t_z.clamp(min=self.v_min, max=self.v_max)
        # Fractional atom index; clamp guards against float round-off at v_max.
        b = ((t_z - self.v_min) / delta_z).clamp(min=0, max=self.atom_size - 1)
        l = b.floor().long()
        u = b.ceil().long()
        # If b is exactly an atom, l == u and both weights (u - b) and (b - l)
        # are 0, which would drop that mass. Widen the bracket by one atom so
        # the full weight lands on the atom at b.
        l[(u > 0) & (l == u)] -= 1
        u[(l < (self.atom_size - 1)) & (l == u)] += 1

        # Row offsets into the flattened (batch * atom_size) buffer.
        offset = (torch.arange(batch_size, device=l.device) * self.atom_size)
        offset = offset.unsqueeze(1).expand(batch_size, self.atom_size)

        proj_dist = torch.zeros(next_dist.size(), dtype=next_dist.dtype,
                                device=next_dist.device)
        proj_dist.view(-1).index_add_(0, (l + offset).view(-1),
                                      (next_dist * (u.float() - b)).view(-1))
        proj_dist.view(-1).index_add_(0, (u + offset).view(-1),
                                      (next_dist * (b - l.float())).view(-1))
        return proj_dist

    def train(self, batch_size=32, epochs=1):
        """Run ``epochs`` distributional updates on prioritized samples.

        Does nothing until the memory holds at least 1000 transitions.
        """
        if 1000 > len(self.memory._storage):
            return

        for _ in range(epochs):
            self.update_count += 1

            # Anneal the importance-sampling exponent towards 1, never above it.
            self.beta = min(1.0, self.beta + self.update_count / 100000 *
                            (1.0 - self.beta))

            (states, actions, rewards, next_states, dones, weights,
             batch_indexes) = self.memory.sample(batch_size, self.beta)

            states = torch.FloatTensor(states).to(self.device)
            actions = torch.LongTensor(actions).to(self.device)
            rewards = torch.FloatTensor(rewards).unsqueeze(-1).to(self.device)
            next_states = torch.FloatTensor(next_states).to(self.device)
            dones = torch.FloatTensor(dones).unsqueeze(-1).to(self.device)
            # Keep weights 1-D (batch,). A (batch, 1) shape would broadcast
            # against td_error (batch,) into a (batch, batch) matrix.
            weights = torch.FloatTensor(weights).to(self.device)
            batch = range(len(states))

            with torch.no_grad():
                next_action = self.dqn_target.forward(next_states).argmax(
                    dim=1)
                next_dist = self.dqn_target.dist(next_states)[batch, next_action]
                proj_dist = self._project_distribution(next_dist, rewards,
                                                       dones, self.gamma)

            # NOTE(review): assumes Network.dist never returns exact zeros
            # (otherwise log gives -inf) -- confirm it clamps.
            dist = self.dqn.dist(states)
            log_p = torch.log(dist[batch, actions])
            td_error = -(proj_dist * log_p).sum(1)

            loss = torch.mean(td_error * weights)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.update_target()

            priorities = td_error.detach().cpu().numpy() + 1e-6
            self.memory.update_priorities(batch_indexes, priorities)

    def update_target(self):
        """Soft update: target <- tau * online + (1 - tau) * target."""
        for target_param, param in zip(self.dqn_target.parameters(),
                                       self.dqn.parameters()):
            target_param.data.copy_(param.data * self.tau + target_param.data *
                                    (1.0 - self.tau))
Example #5
0
class Rainbow:
    """Distributional DQN agent with prioritized replay and soft target updates.

    There is no epsilon schedule. Exploration presumably comes from noise inside
    ``Network``, which has ``reset_noise()`` called after every update. ``n_step``
    is stored but multi-step returns are not implemented yet; ``n_envs`` is
    currently unused.
    """

    def __init__(self, observation_space, action_space, lr=7e-4, gamma=0.99, tau=0.01, n_step=3, n_envs=1):
        """Create replay memory, support and online/target networks.

        Args:
            observation_space: space whose ``shape[0]`` is the network input size.
            action_space: discrete space; ``n`` is the number of actions.
            lr: Adam learning rate.
            gamma: discount factor.
            tau: soft target-update coefficient.
            n_step: intended multi-step horizon (not used yet).
            n_envs: unused.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.gamma = gamma
        self.tau = tau
        self.beta = 0.6
        self.memory = PrioritizedReplayBuffer(10000, 0.5)
        self.n_step = n_step
        self.action_space = action_space

        self.v_min = 0.
        self.v_max = 500.
        self.atom_size = 51
        self.support = torch.linspace(self.v_min, self.v_max, self.atom_size).to(self.device)

        self.update_count = 0
        self.dqn = Network(observation_space.shape[0], action_space.n, self.atom_size, self.support).to(self.device)
        self.dqn_target = Network(observation_space.shape[0], action_space.n, self.atom_size, self.support).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)

    def act(self, state):
        """Return the greedy action (numpy array) for each row of ``state``."""
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device)
            action = self.dqn.forward(state).argmax(dim=-1)
        return action.cpu().numpy()

    def remember(self, states, actions, rewards, new_states, dones):
        """Store each transition of a batch in the replay memory."""
        for i in range(len(states)):
            self.memory.add(states[i], actions[i], rewards[i], new_states[i], dones[i])

    def train(self, batch_size=32):
        """Run one prioritized distributional update, then resample network noise.

        Does nothing until the memory holds at least 500 transitions.
        """
        if 500 > len(self.memory._storage):
            return

        self.update_count += 1

        # Anneal the importance-sampling exponent towards 1, never above it.
        self.beta = min(1.0, self.beta + self.update_count / 100000 * (1.0 - self.beta))

        (states, actions, rewards, next_states, dones, weights,
         batch_indexes) = self.memory.sample(batch_size, self.beta)
        # Keep weights 1-D (batch,). A (batch, 1) shape would broadcast against
        # td_error (batch,) into a (batch, batch) matrix.
        weights = torch.FloatTensor(weights).to(self.device)

        td_error = self.calculate_loss(states, actions, rewards, next_states, dones, self.gamma)
        # TODO: n-step returns (self.n_step) are not implemented. An n-step loss
        # from a second buffer, discounted by gamma ** n_step, would be added here.
        loss = torch.mean(td_error * weights)

        self.optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(self.dqn.parameters(), 10.0)
        self.optimizer.step()

        self.update_target()

        priorities = td_error.detach().cpu().numpy() + 1e-6
        self.memory.update_priorities(batch_indexes, priorities)

        self.dqn.reset_noise()
        self.dqn_target.reset_noise()

    def _project_distribution(self, next_dist, rewards, dones, gamma):
        """Project the Bellman-shifted distribution back onto ``self.support``.

        Args:
            next_dist: (batch, atom_size) probabilities for the chosen next action.
            rewards: (batch, 1) rewards.
            dones: (batch, 1) terminal flags (1.0 = terminal).
            gamma: discount applied to the support.

        Returns:
            A (batch, atom_size) tensor whose rows keep the mass of ``next_dist``.
        """
        batch_size = next_dist.size(0)
        delta_z = float(self.v_max - self.v_min) / (self.atom_size - 1)

        t_z = rewards + (1 - dones) * gamma * self.support
        t_z = t_z.clamp(min=self.v_min, max=self.v_max)
        # Fractional atom index; clamp guards against float round-off at v_max.
        b = ((t_z - self.v_min) / delta_z).clamp(min=0, max=self.atom_size - 1)
        l = b.floor().long()
        u = b.ceil().long()
        # If b is exactly an atom, l == u and both weights (u - b) and (b - l)
        # are 0, which would drop that mass. Widen the bracket by one atom so
        # the full weight lands on the atom at b.
        l[(u > 0) & (l == u)] -= 1
        u[(l < (self.atom_size - 1)) & (l == u)] += 1

        # Row offsets into the flattened (batch * atom_size) buffer.
        offset = (torch.arange(batch_size, device=l.device) * self.atom_size)
        offset = offset.unsqueeze(1).expand(batch_size, self.atom_size)

        proj_dist = torch.zeros(next_dist.size(), dtype=next_dist.dtype, device=next_dist.device)
        proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1))
        proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))
        return proj_dist

    def calculate_loss(self, states, actions, rewards, next_states, dones, gamma):
        """Return the per-sample cross-entropy between the projected target and
        the online distribution (shape (batch,))."""
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(-1).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).unsqueeze(-1).to(self.device)
        batch = range(len(states))

        with torch.no_grad():
            # NOTE(review): the target net both selects and evaluates the next
            # action (no double-DQN selection) -- confirm this is intended.
            next_action = self.dqn_target.forward(next_states).argmax(dim=1)
            next_dist = self.dqn_target.dist(next_states)[batch, next_action]
            proj_dist = self._project_distribution(next_dist, rewards, dones, gamma)

        # NOTE(review): assumes Network.dist never returns exact zeros -- confirm.
        dist = self.dqn.dist(states)
        log_p = torch.log(dist[batch, actions])
        return -(proj_dist * log_p).sum(1)

    def update_target(self):
        """Soft update: target <- (1 - tau) * target + tau * online."""
        with torch.no_grad():
            for target_param, param in zip(self.dqn_target.parameters(), self.dqn.parameters()):
                target_param.data.mul_(1 - self.tau)
                torch.add(target_param.data, param.data, alpha=self.tau, out=target_param.data)

    def hard_update_target(self):
        """Copy the online weights into the target network."""
        self.dqn_target.load_state_dict(self.dqn.state_dict())

    def save_model(self, path):
        """Save the online network's state_dict to ``path``."""
        torch.save(self.dqn.state_dict(), path)

    def load_model(self, path):
        """Load online-network weights from ``path`` onto this agent's device."""
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.dqn.load_state_dict(torch.load(path, map_location=self.device))