Example No. 1
class Agent:
    def __init__(self, env, state_space, action_space, device, learning_rate, buffer_size,
                 batch_size, gamma, in_channels, train_freq=4, target_update_freq=1e4, is_ddqn=False):
        self.env = env
        self.state_space = state_space
        self.action_space = action_space
        
        self.QNetwork_local = QNetwork(in_channels, self.state_space, self.action_space.n, device=device).to(device)
        self.QNetwork_local.init_weights()
        self.QNetwork_target = QNetwork(in_channels, self.state_space, self.action_space.n, device=device).to(device)
        self.QNetwork_target.load_state_dict(self.QNetwork_local.state_dict())
        
        self.optimizer = torch.optim.RMSprop(self.QNetwork_local.parameters(), lr=learning_rate, alpha=0.95, eps=0.01, centered=True)
        self.criterion = torch.nn.MSELoss()
        self.memory = ReplayBuffer(capacity=int(buffer_size), batch_size=batch_size)
        self.step_count = 0
        self.batch_size = batch_size
        self.gamma = gamma
        
        self.device = device
        self.buffer_size = buffer_size
        self.num_train_updates = 0
        self.train_freq = train_freq
        self.target_update_freq = target_update_freq

        self.is_ddqn = is_ddqn
        
        print('Agent initialized with memory size: {}'.format(buffer_size))
        
    def act(self, state, epsilon):
        if random.random() > epsilon:
            # switch to evaluation mode, evaluate the state without tracking
            # gradients, then switch back to train mode
            self.QNetwork_local.eval()
            with torch.no_grad():
                actions = self.QNetwork_local(state)
            self.QNetwork_local.train()

            actions_np = actions.cpu().numpy()[0]
            best_action_idx = int(np.argmax(actions_np))
            return best_action_idx
        else:
            rand_action = self.action_space.sample()
            return rand_action
            
        
    def step(self, state, action, reward, next_state, done, add_memory=True):
        # TODO calculate priority here?
        if add_memory: 
            priority = 1.
            reward_clip = np.sign(reward)
            self.memory.add(state=state, action=action, next_state=next_state, reward=reward_clip, done=done, priority=priority)
            self.step_count = (self.step_count + 1) % self.train_freq

            # learn every train_freq steps, once the replay buffer is full
            self.network_is_updated = False
            if self.step_count == 0 and len(self.memory) == self.buffer_size:
                samples = self.memory.random_sample(self.device)
                self.learn(samples)
                self.num_train_updates +=1
                self.network_is_updated = True
            
    def learn(self, samples):
        states, actions, rewards, next_states, dones = samples
        

        if self.is_ddqn:
            # DDQN: select the greedy action with the local network,
            # then evaluate it with the target network
            next_actions = torch.argmax(self.QNetwork_local(next_states).detach(), dim=1).unsqueeze(1)
            q_target_next = self.QNetwork_target(next_states).detach().gather(1, next_actions)
        else:
            # DQN: take the maximum action value from the target network
            q_target_next = self.QNetwork_target(next_states).detach().max(1)[0].unsqueeze(1)

        # Q-values of the actions actually taken, from the local network
        q_local_current = self.QNetwork_local(states).gather(1, actions)

        self.optimizer.zero_grad() #cleans up previous values

        # TD target and MSE loss on the TD error
        TD_target = rewards + (self.gamma * q_target_next * (1 - dones))
        TD_error = self.criterion(q_local_current, TD_target)
        TD_error.backward()
        torch.nn.utils.clip_grad_norm_(self.QNetwork_local.parameters(), 5.)
        self.optimizer.step()
        
        # periodically sync the target network with the local network
        if self.num_train_updates % self.target_update_freq == 0:
            self.QNetwork_target.load_state_dict(self.QNetwork_local.state_dict())
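
A minimal, hypothetical usage sketch for this first agent: it assumes a Gym-style Atari environment already wrapped to emit stacked 84x84 frames, that the QNetwork and ReplayBuffer classes referenced above are importable, and an illustrative epsilon schedule. The environment name, preprocess() helper, and hyperparameter values below are assumptions, not part of the original example.

# Usage sketch (illustrative only) for the Example No. 1 agent.
import torch
import gym

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = gym.make('PongNoFrameskip-v4')      # illustrative environment choice

def preprocess(obs):
    # hypothetical helper: assumes the env is wrapped to return stacked (4, 84, 84)
    # frames and only converts them to a batched float tensor for QNetwork
    return torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)

agent = Agent(env=env, state_space=(84, 84), action_space=env.action_space,
              device=device, learning_rate=2.5e-4, buffer_size=1e5,
              batch_size=32, gamma=0.99, in_channels=4, is_ddqn=True)

epsilon, eps_min, eps_step = 1.0, 0.1, 1e-6
state = preprocess(env.reset())
for t in range(1_000_000):
    action = agent.act(state, epsilon)
    frame, reward, done, _ = env.step(action)           # old 4-tuple gym API
    next_state = preprocess(frame)
    agent.step(state, action, reward, next_state, done)
    state = preprocess(env.reset()) if done else next_state
    epsilon = max(eps_min, epsilon - eps_step)

Note that, as written, learning only begins once the replay buffer is completely full, per the condition in step().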
Example No. 2
class Agent:
    def __init__(self,
                 device,
                 state_size,
                 action_size,
                 buffer_size=10,
                 batch_size=10,
                 learning_rate=0.1,
                 discount_rate=0.99,
                 eps_decay=0.9,
                 tau=0.1,
                 steps_per_update=4):
        self.device = device
        self.state_size = state_size
        self.action_size = action_size

        self.q_network_control = QNetwork(state_size, action_size).to(device)
        self.q_network_target = QNetwork(state_size, action_size).to(device)
        self.optimizer = torch.optim.Adam(self.q_network_control.parameters(),
                                          lr=learning_rate)

        self.batch_size = batch_size
        self.replay_buffer = ReplayBuffer(device, state_size, action_size,
                                          buffer_size)

        self.discount_rate = discount_rate

        self.eps = 1.0
        self.eps_decay = eps_decay

        self.tau = tau

        self.step_count = 0
        self.steps_per_update = steps_per_update

    def policy(self, state):
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        return self.epsilon_greedy_policy(self.eps, state)

    def epsilon_greedy_policy(self, eps, state):
        self.q_network_control.eval()
        with torch.no_grad():
            action_values = self.q_network_control(state)
        self.q_network_control.train()

        if random.random() > eps:
            greedy_choice = np.argmax(action_values.cpu().data.numpy())
            return greedy_choice
        else:
            return random.choice(np.arange(self.action_size))

    def step(self, state, action, reward, next_state, done):
        p = self.calculate_p(state, action, reward, next_state, done)
        self.replay_buffer.add(state, action, reward, next_state, done, p)
        if self.step_count % self.steps_per_update == 0:
            self.learn()
        self.step_count += 1

    def learn(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        states, actions, rewards, next_states, dones, p = \
            self.replay_buffer.sample(self.batch_size)

        error = self.bellman_eqn_error(states, actions, rewards, next_states,
                                       dones)
        # importance-sampling weights for prioritized replay (exponent beta = 1)
        importance_scaling = (self.replay_buffer.buffer_size * p)**-1
        loss = (importance_scaling * (error**2)).sum() / self.batch_size
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_target()

    def bellman_eqn_error(self, states, actions, rewards, next_states, dones):
        """Double DQN error - use the control network to get the best action
        and apply the target network to it to get the target reward which is
        used for the bellman eqn error.
        """
        self.q_network_control.eval()
        with torch.no_grad():
            a_max = self.q_network_control(next_states).argmax(1).unsqueeze(1)
            # evaluate the greedy action with the target network; kept out of the
            # autograd graph since only the control network is optimized
            target_action_values = self.q_network_target(next_states).gather(
                1, a_max)
        target_rewards = rewards + self.discount_rate * (1 - dones) \
                         * target_action_values

        self.q_network_control.train()
        current_rewards = self.q_network_control(states).gather(1, actions)
        error = current_rewards - target_rewards
        return error

    def calculate_p(self, state, action, reward, next_state, done):
        """Priority of a single transition: |bellman eqn error| + small constant."""
        next_state = torch.from_numpy(next_state[np.newaxis, :]).float().to(
            self.device)
        state = torch.from_numpy(state[np.newaxis, :]).float().to(self.device)
        action = torch.from_numpy(np.array([[action]])).long().to(self.device)
        reward = torch.from_numpy(np.array([reward])).float().to(self.device)
        done = torch.from_numpy(np.array([[done]], dtype=np.uint8)).float().to(
            self.device)

        return abs(
            self.bellman_eqn_error(state, action, reward, next_state,
                                   done)) + 1e-3

    def update_target(self):
        """Soft update: theta_target <- tau*theta_control + (1 - tau)*theta_target."""
        for target_param, control_param in zip(
                self.q_network_target.parameters(),
                self.q_network_control.parameters()):
            target_param.data.copy_(self.tau * control_param.data +
                                    (1.0 - self.tau) * target_param.data)

    def end_of_episode(self):
        self.eps *= self.eps_decay
        self.step_count = 0

    def save(self, path):
        torch.save(self.q_network_control.state_dict(), path)

    def restore(self, path):
        self.q_network_control.load_state_dict(torch.load(path))
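
A similarly minimal, hypothetical usage sketch for this second agent: it assumes a Gym-style classic-control environment with a flat observation vector, and that the QNetwork and ReplayBuffer classes it references are defined elsewhere. The environment name, hyperparameter values, and checkpoint path are illustrative assumptions.

# Usage sketch (illustrative only) for the Example No. 2 agent.
import torch
import gym

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = gym.make('CartPole-v1')             # illustrative environment choice
agent = Agent(device=device,
              state_size=env.observation_space.shape[0],
              action_size=env.action_space.n,
              buffer_size=int(1e5), batch_size=64,
              learning_rate=5e-4, discount_rate=0.99,
              eps_decay=0.995, tau=1e-3, steps_per_update=4)

for episode in range(500):
    state = env.reset()                   # old gym API: reset() returns the state
    done = False
    while not done:
        action = agent.policy(state)
        next_state, reward, done, _ = env.step(int(action))
        agent.step(state, action, reward, next_state, done)
        state = next_state
    agent.end_of_episode()                # decays epsilon and resets step_count

agent.save('checkpoint.pth')              # hypothetical checkpoint path

Here epsilon decays once per episode via end_of_episode(), while the target network tracks the control network through the soft tau update performed after every learning step.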