import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# QNetwork, ReplayBuffer and PrioritisedReplayBuffer are assumed to be defined
# elsewhere in each example's repository.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DQNAgent():
    """A deep Q network agent.

    An agent using a deep Q network with a replay buffer, soft target updates
    and multiplicative epsilon decay that can learn to solve a task by
    interacting with its environment. Includes options to train using double
    deep Q learning and a prioritised replay buffer.
    """

    def __init__(
        self,
        buffer_size,
        seed,
        state_size,
        action_size,
        hidden_layers,
        epsilon,
        epsilon_decay,
        epsilon_min,
        gamma,
        tau,
        learning_rate,
        update_frequency,
        double_Q=False,
        prioritised_replay_buffer=False,
        alpha=None,
        beta=None,
        beta_increment_size=None,
        base_priority=None,
        max_priority=None,
        training_scores=None,
        step_number=0,
    ):
        """DQNAgent initialisation function.

        Args:

            buffer_size (int): maximum size of the replay buffer.
            seed (int): random seed used for batch selection.

            state_size (int): dimension of state space for input to Q network.
            action_size (int): dimension of action space for value predictions.
            hidden_layers (list[int]): list of dimensions for the hidden layers required.

            epsilon (float): probability of choosing non-greedy action in policy.
            epsilon_decay (float): multiplicative decay factor applied to epsilon after each step.
            epsilon_min (float): minimum value epsilon is allowed to decay to.

            gamma (float): discount factor for future expected returns.
            tau (float): soft update factor defining how far to shift the
                       target network parameters towards the current network parameters.

            learning_rate (float): learning rate for gradient descent optimisation.
            update_frequency (int): how often to update target Q network parameters.

            double_Q (bool): set true to train using double deep Q learning.

            prioritised_replay_buffer (bool): set true to use a prioritised replay buffer.
            alpha (float): priority scaling hyperparameter.
            beta (float): initial importance sampling scaling hyperparameter (annealed by beta_increment_size).
            beta_increment_size (float): beta annealing rate.
            base_priority (float): base priority to ensure non-zero sampling probability.
            max_priority (float): initial maximum priority.

            training_scores (list[int]): rewards gained in previous training episodes
                                (primarily used when reloading saved agents).
            step_number (int): number of steps the agent has taken
                                (primarily used when reloading saved agents).

        Notes: Setting tau = 1 recovers classic DQN with a full target network update.
               When using soft updates (tau < 1) it is recommended to update the
               target network frequently, i.e. to use a small update_frequency.
        """

        self.buffer_size = buffer_size
        self.seed = seed
        if prioritised_replay_buffer:
            self.replay_buffer = PrioritisedReplayBuffer(
                buffer_size,
                alpha,
                beta,
                beta_increment_size,
                base_priority,
                max_priority,
                seed,
            )
        else:
            self.replay_buffer = ReplayBuffer(buffer_size, seed)

        self.state_size = state_size
        self.action_size = action_size
        self.hidden_layers = hidden_layers

        self.Q_net = QNetwork(state_size, action_size, hidden_layers).to(device)
        self.target_Q = QNetwork(state_size, action_size, hidden_layers).to(device)
        # Start the target network from the same weights as the online network.
        self.target_Q.load_state_dict(self.Q_net.state_dict())
        self.optimizer = optim.Adam(self.Q_net.parameters(), lr=learning_rate)

        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        self.gamma = gamma
        self.tau = tau

        self.learning_rate = learning_rate
        self.update_frequency = update_frequency

        self.double_Q = double_Q

        self.prioritised_replay_buffer = prioritised_replay_buffer
        self.alpha = alpha
        self.beta = beta
        self.beta_increment_size = beta_increment_size
        self.base_priority = base_priority
        self.max_priority = max_priority

        self.step_number = step_number
        if training_scores is None:
            self.training_scores = []
        else:
            self.training_scores = training_scores

    def step(self, state, action, reward, next_state, done, batch_size):
        """
        A function that records experiences into the replay buffer after each
        environment step, then update the current network parameter and soft
        updates target network parameters.
        """
        self.replay_buffer.add(state, action, reward, next_state, done)
        self.update_Q(batch_size)
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

        self.step_number += 1
        if self.step_number % self.update_frequency == 0:
            self.soft_update_target_Q()

    def act_epsilon_greedy(self, state, greedy=False):
        """ Returns an epsilon-greedy action (the greedy action if greedy=True). """
        if greedy or random.random() > self.epsilon:
            state = torch.from_numpy(state).float().unsqueeze(0).to(device)
            self.Q_net.eval()
            with torch.no_grad():
                action_values = self.Q_net(state)
            self.Q_net.train()
            return torch.argmax(action_values).cpu().item()

        return np.random.randint(self.action_size)

    def update_Q(self, batch_size):
        """
        Updates the parameters of the current Q network using backpropagation
        and experiences from the replay buffer.
        """

        if len(self.replay_buffer) > 2*batch_size:

            experience = self.replay_buffer.sample(batch_size)

            states = torch.FloatTensor(experience[0]).to(device)
            actions = torch.LongTensor(experience[1]).unsqueeze(1).to(device)
            rewards = torch.FloatTensor(experience[2]).unsqueeze(1).to(device)
            next_states = torch.FloatTensor(experience[3]).to(device)
            done_tensor = torch.FloatTensor(experience[4]).unsqueeze(1).to(device)

            if self.double_Q:
                # Double DQN: select the greedy next action with the online
                # network, then evaluate it with the target network.
                next_actions = self.Q_net(next_states).detach().argmax(1, keepdim=True)
                Q_target_next = self.target_Q(next_states).detach().gather(1, next_actions)
            else:
                Q_target_next = self.target_Q(next_states).detach().max(1, keepdim=True)[0]

            Q_expected = self.Q_net(states).gather(1, actions)
            Q_target = rewards + self.gamma * Q_target_next * (1 - done_tensor)

            if self.prioritised_replay_buffer:
                idx_list = experience[5]
                weights = torch.FloatTensor(experience[6]).unsqueeze(1).to(device)
                td_error = Q_target - Q_expected
                priority_list = torch.abs(td_error.squeeze().detach()).cpu().numpy()
                self.replay_buffer.update(idx_list, priority_list)
                # Importance sampling weights scale the squared TD error.
                loss = torch.mean(weights * td_error ** 2)
            else:
                loss = F.mse_loss(Q_expected, Q_target)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def soft_update_target_Q(self):
        """ Soft updates the target Q network """
        for target_Q_param, Q_param in zip(self.target_Q.parameters(), self.Q_net.parameters()):
            target_Q_param.data = self.tau * Q_param.data + (1 - self.tau) * target_Q_param.data

    def save_agent(self, name, path=""):
        """ Saves agent parameters for loading using load_agent function
        Note: it is torch convention to save models with .pth extension
        """
        params = (
            self.buffer_size,
            self.seed,
            self.state_size,
            self.action_size,
            self.hidden_layers,
            self.epsilon,
            self.epsilon_decay,
            self.epsilon_min,
            self.gamma,
            self.tau,
            self.learning_rate,
            self.update_frequency,
            self.double_Q,
            self.prioritised_replay_buffer,
            self.alpha,
            # beta only exists on the prioritised buffer; fall back otherwise.
            self.replay_buffer.beta if self.prioritised_replay_buffer else self.beta,
            self.beta_increment_size,
            self.base_priority,
            self.max_priority,
            self.training_scores,
            self.step_number,
        )

        checkpoint = {
            "params": params,
            "state_dict": self.Q_net.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }

        path += name
        torch.save(checkpoint, path)
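
The docstring above refers to a load_agent function that is not part of this listing. A minimal sketch of what such a loader could look like, assuming the checkpoint layout produced by save_agent (the load_agent name and behaviour are assumptions, not the original author's code):

def load_agent(path):
    """Hypothetical counterpart to DQNAgent.save_agent: rebuild an agent from
    the (params, state_dict, optimizer_state_dict) checkpoint saved above."""
    checkpoint = torch.load(path, map_location=device)

    # The params tuple is stored in the same order as the __init__ arguments,
    # so it can be unpacked positionally.
    agent = DQNAgent(*checkpoint["params"])

    # Restore the trained weights into both networks and the optimizer state.
    agent.Q_net.load_state_dict(checkpoint["state_dict"])
    agent.target_Q.load_state_dict(checkpoint["state_dict"])
    agent.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return agent
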
Example No. 2
class DDPGAgent():
    def __init__(self,
                 model_fn,
                 action_scale=1.0,
                 gamma=0.99,
                 exploration_noise=None,
                 batch_size=64,
                 replay_memory=100000,
                 replay_start=1000,
                 tau=1e-3,
                 optimizer=optim.Adam,
                 actor_learning_rate=1e-4,
                 critic_learning_rate=1e-3,
                 clip_gradients=None,
                 action_repeat=1,
                 update_freq=1,
                 random_seed=None):
        # create online and target networks
        self.online_network = model_fn()
        self.target_network = model_fn()

        # create the optimizers for the online_network
        self.actor_optimizer = optimizer(self.online_network.actor_params,
                                         lr=actor_learning_rate)
        self.critic_optimizer = optimizer(self.online_network.critic_params,
                                          lr=critic_learning_rate)
        self.clip_gradients = clip_gradients

        # assign the online network variables to the target network
        self.target_network.load_state_dict(self.online_network.state_dict())

        self.replay_buffer = ReplayBuffer(memory_size=replay_memory,
                                          seed=random_seed)
        self.exploration_noise = exploration_noise

        self.action_scale = action_scale
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.replay_start = replay_start
        self.action_repeat = action_repeat
        self.update_freq = update_freq
        self.random_seed = random_seed

        self.reset_current_step()

    def reset_current_step(self):
        self.current_step = 0

    def soft_update(self):
        for target_param, online_param in zip(
                self.target_network.parameters(),
                self.online_network.parameters()):
            target_param.detach_()
            target_param.copy_(target_param * (1.0 - self.tau) +
                               online_param * self.tau)

    def add_to_replay_memory(self, state, action, reward, next_state,
                             terminal):
        experience = (state, action, reward, next_state, terminal)
        self.replay_buffer.add(experience)

    def action(self, state):
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)

        if (self.current_step % self.action_repeat
                == 0) or (not hasattr(self, '_previous_action')):
            action = self.online_network(state)
            action = action.squeeze().detach().numpy()

            if self.exploration_noise is not None:
                action = action + self.exploration_noise.sample()
                action = np.clip(action, -self.action_scale, self.action_scale)
        else:
            action = self._previous_action

        self._previous_action = action

        return action

    def update_target(self, state, action, reward, next_state, terminal):
        next_action = self.target_network(next_state).detach()
        Q_sa_next = self.target_network.critic_value(next_state,
                                                     next_action).detach()

        # reward and terminal broadcast against Q_sa_next; all terms are
        # already detached, so no extra torch.tensor() wrapping is needed.
        update_target = (reward.unsqueeze(-1) +
                         self.gamma * Q_sa_next * (1 - terminal).unsqueeze(-1))

        return update_target

    def update(self, state, action, reward, next_state, terminal):
        self.add_to_replay_memory(state, action, reward, next_state, terminal)

        if terminal and (self.exploration_noise is not None):
            # Reset the noise process at the end of an episode, if supported.
            try:
                self.exploration_noise.reset_states()
            except AttributeError:
                pass

        if self.current_step >= self.replay_start:
            if self.current_step % self.update_freq == 0:
                experiences = self.replay_buffer.sample(self.batch_size)
                state, action, reward, next_state, terminal = zip(*experiences)

                state = torch.tensor(state, dtype=torch.float32)
                action = torch.tensor(action, dtype=torch.float32)
                reward = torch.tensor(reward, dtype=torch.float32)
                next_state = torch.tensor(next_state, dtype=torch.float32)
                terminal = torch.tensor(terminal, dtype=torch.float32)

                update_target = self.update_target(state, action, reward,
                                                   next_state, terminal)
                Q_sa = self.online_network.critic_value(state, action)
                critic_loss = (Q_sa -
                               update_target).pow(2).mul(0.5).sum(-1).mean()

                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                if self.clip_gradients:
                    nn.utils.clip_grad_norm_(self.online_network.critic_params,
                                             self.clip_gradients)
                self.critic_optimizer.step()

                action = self.online_network(state)
                policy_loss = -self.online_network.critic_value(state,
                                                                action).mean()

                self.actor_optimizer.zero_grad()
                policy_loss.backward()
                if self.clip_gradients:
                    nn.utils.clip_grad_norm_(self.online_network.actor_params,
                                             self.clip_gradients)
                self.actor_optimizer.step()

                self.soft_update()

        self.current_step += 1
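
A rough sketch of how this DDPGAgent could be driven by an environment. The Gym-style env, the episode counts and the train_ddpg name are assumptions for illustration; only agent.action() and agent.update() come from the class above:

def train_ddpg(env, agent, n_episodes=200, max_steps=1000):
    """Hypothetical training loop for the DDPGAgent above. Assumes a Gym-style
    env with reset() -> state and step(action) -> (next_state, reward, done, info)."""
    scores = []
    for episode in range(n_episodes):
        state = env.reset()
        score = 0.0
        for _ in range(max_steps):
            action = agent.action(state)                            # actor + exploration noise
            next_state, reward, done, _ = env.step(action)
            agent.update(state, action, reward, next_state, done)   # store transition and learn
            state = next_state
            score += reward
            if done:
                break
        scores.append(score)
    return scores
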
Example No. 3
class MADDPGAgent():
    def __init__(self, n_agents, model_fn,
                 action_scale = 1.0,
                 gamma = 0.99,
                 exploration_noise_fn = None,
                 batch_size = 64,
                 replay_memory = 100000,
                 replay_start = 100,
                 tau = 1e-3,
                 optimizer = optim.Adam,
                 actor_learning_rate = 1e-4,
                 critic_learning_rate = 1e-3,
                 clip_gradients = None,
                 share_weights = False,
                 action_repeat = 1,
                 update_freq = 1,
                 random_seed = None):
        # create online and target networks for each agent
        self.n_agents = n_agents
        
        self.online_networks = [model_fn() for _ in range(self.n_agents)]
        self.target_networks = [model_fn() for _ in range(self.n_agents)]
        
        self.actor_optimizers = [optimizer(agent.actor_params, 
                                           lr = actor_learning_rate) for agent in self.online_networks]
        self.critic_optimizers = [optimizer(agent.critic_params, 
                                            lr = critic_learning_rate) for agent in self.online_networks]
        
        if exploration_noise_fn:
            self.exploration_noise = [exploration_noise_fn() for _ in range(self.n_agents)]
        else:
            self.exploration_noise = None
                                         
        # assign the online network variables to the target network
        for target_network, online_network in zip(self.target_networks, self.online_networks):
            target_network.load_state_dict(online_network.state_dict())
        
        self.replay_buffer = ReplayBuffer(memory_size = replay_memory, seed = random_seed)
        
        self.share_weights = share_weights
        self.action_scale = action_scale
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.clip_gradients = clip_gradients
        self.replay_start = replay_start
        self.action_repeat = action_repeat
        self.update_freq = update_freq
        self.random_seed = random_seed
        
        self.reset_current_step()
        
    def reset_current_step(self):
        self.current_step = 0
        
    def soft_update(self):
        for target_network, online_network in zip(self.target_networks, self.online_networks):
            for target_param, online_param in zip(target_network.parameters(), online_network.parameters()):
                target_param.detach_()
                target_param.copy_(target_param * (1.0 - self.tau) + online_param * self.tau)
                
    def assign_weights(self):
        for target_network, online_network in zip(self.target_networks, self.online_networks):
            target_network.load_state_dict(self.target_networks[0].state_dict())
            online_network.load_state_dict(self.online_networks[0].state_dict())
    
    def add_to_replay_memory(self, state, action, reward, next_state, terminal):
        experience = (state, action, reward, next_state, terminal)
        self.replay_buffer.add(experience)
    
    def action(self, state):
        if (self.current_step % self.action_repeat == 0) or (not hasattr(self, '_previous_action')):
            actions = []
            for i in range(self.n_agents):
                obs = torch.tensor(state[i], dtype = torch.float32).unsqueeze(0)
                action = self.online_networks[i](obs)
                action = action.squeeze().detach().numpy()
                
                if self.exploration_noise:
                    action = action + self.exploration_noise[i].sample()
                    action = np.clip(action, -self.action_scale, self.action_scale)
                actions.append(action)
                
            action = np.asarray(actions)
        else:
            action = self._previous_action
                
        self._previous_action = action

        return action
    
    def update_target(self, state, action, reward, next_state, terminal):
        all_next_actions = []
        for i in range(self.n_agents):
            next_action = self.target_networks[i](next_state[:, i, :])
            all_next_actions.append(next_action)

        all_next_actions = torch.cat(all_next_actions, dim = 1)
        all_next_states = next_state.view(-1, next_state.shape[1] * next_state.shape[2])

        Q_sa_next = self.target_networks[self.current_agent].critic_value(all_next_states, all_next_actions)

        reward = reward[:, self.current_agent].unsqueeze(-1)
        terminal = terminal[:, self.current_agent].unsqueeze(-1)
 
        update_target = reward + self.gamma * Q_sa_next * (1 - terminal)
        update_target = update_target.detach()
        
        return update_target
    
    def update(self, state, action, reward, next_state, terminal):
        self.add_to_replay_memory(state, action, reward, next_state, terminal)
        
        if np.any(terminal) and (self.exploration_noise is not None):
            for i in range(self.n_agents):
                try:
                    self.exploration_noise[i].reset_states()
                except AttributeError:
                    pass
        
        if self.current_step >= self.replay_start:
            if self.current_step % self.update_freq == 0:
                if self.share_weights:
                    update_agents = 1
                else:
                    update_agents = self.n_agents

                for i in range(update_agents):
                    self.current_agent = i
                
                    experiences = self.replay_buffer.sample(self.batch_size)     
                    state, action, reward, next_state, terminal = zip(*experiences)
                    
                    state = torch.tensor(state, dtype = torch.float32)
                    action = torch.tensor(action, dtype = torch.float32)
                    reward = torch.tensor(reward, dtype = torch.float32)
                    next_state = torch.tensor(next_state, dtype = torch.float32)
                    terminal = torch.tensor(terminal, dtype = torch.float32)

                    all_actions = action.view(-1, action.shape[1] * action.shape[2])
                    all_states = state.view(-1, state.shape[1] * state.shape[2])
                    
                    update_target = self.update_target(state, action, reward, next_state, terminal)
                    
                    Q_sa = self.online_networks[self.current_agent].critic_value(all_states, all_actions)
                    critic_loss = F.mse_loss(Q_sa, update_target)
                    
                    self.critic_optimizers[self.current_agent].zero_grad()
                    critic_loss.backward()
                    if self.clip_gradients:
                        nn.utils.clip_grad_norm_(self.online_networks[self.current_agent].critic_params, self.clip_gradients)
                    self.critic_optimizers[self.current_agent].step()

                    agent_action = self.online_networks[self.current_agent](state[:, self.current_agent, :])

                    predicted_actions = action.clone().detach()
                    predicted_actions[:, self.current_agent] = agent_action
                    predicted_actions = predicted_actions.view(-1, predicted_actions.shape[1] * predicted_actions.shape[2])

                    policy_loss = -self.online_networks[self.current_agent].critic_value(all_states, predicted_actions).mean()

                    self.actor_optimizers[self.current_agent].zero_grad()
                    policy_loss.backward()
                    if self.clip_gradients:
                        nn.utils.clip_grad_norm_(self.online_networks[self.current_agent].actor_params, self.clip_gradients)
                    self.actor_optimizers[self.current_agent].step()
                    
                self.soft_update()
                
                if self.share_weights:
                    self.assign_weights()
            
        self.current_step += 1
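
The centralised critic in this MADDPGAgent sees the observations and actions of all agents concatenated along the feature dimension (the view() calls above). A small standalone illustration of that reshaping; the batch size, agent count and dimensions below are arbitrary example values:

batch_size, n_agents, obs_dim, act_dim = 64, 2, 24, 2

# Per-agent observations and actions as they come out of the replay buffer.
state = torch.randn(batch_size, n_agents, obs_dim)
action = torch.randn(batch_size, n_agents, act_dim)

# Flatten across agents, exactly as in MADDPGAgent.update / update_target.
all_states = state.view(-1, state.shape[1] * state.shape[2])      # shape (64, 48)
all_actions = action.view(-1, action.shape[1] * action.shape[2])  # shape (64, 4)

print(all_states.shape, all_actions.shape)
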
Example No. 4
class DQNAgent():
    """Interacts with and learns from the environment."""

    def __init__(self,
                 state_size,
                 action_size,
                 buffer_size,
                 batch_size,
                 gamma,
                 tau,
                 lr,
                 hidden_1,
                 hidden_2,
                 update_every,
                 epsilon,
                 epsilon_min,
                 eps_decay,
                 seed
                 ):
        """Initialize an Agent object.
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.update_every = update_every
        random.seed(seed)
        self.seed = seed
        self.learn_steps = 0
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.eps_decay = eps_decay

        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size, seed, hidden_1, hidden_2).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size, seed, hidden_1, hidden_2).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=lr)

        # Replay memory
        self.memory = ReplayBuffer(self.action_size, self.buffer_size, self.batch_size, self.seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0


    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:
            # Sample if enough samples are available
            if len(self.memory) > self.batch_size:
                experiences = self.memory.sample()
                self.learn(experiences)

    def act(self, state):
        """Returns actions for given state as per current policy.
        Params
        ======
            state (array_like): current state
        """
        self.epsilon = max(self.epsilon*self.eps_decay, self.epsilon_min)

        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() > self.epsilon:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Variable]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
        # Compute Q targets for current states
        Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones))

        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.learn_steps += 1

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target)

    def soft_update(self, local_model, target_model):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)
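
A minimal per-step usage sketch for this agent; the Gym-style env below is an assumption, not part of the original example. Note that act() decays epsilon on every call and step() handles replay storage, learning and the soft target update internally:

state = env.reset()
done = False
while not done:
    action = agent.act(state)                              # epsilon-greedy action
    next_state, reward, done, _ = env.step(action)
    agent.step(state, action, reward, next_state, done)    # store, learn, soft update
    state = next_state
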
Example No. 5
class DQNAgent():
    def __init__(self,
                 model,
                 model_params,
                 state_processor,
                 n_actions,
                 gamma=0.99,
                 epsilon=1.0,
                 min_epsilon=1e-2,
                 epsilon_decay=.999,
                 loss_function=F.smooth_l1_loss,
                 optimizer=optim.Adam,
                 learning_rate=1e-3,
                 l2_regularization=0.0,
                 batch_size=32,
                 replay_memory=1000000,
                 replay_start=50000,
                 target_update_freq=1000,
                 action_repeat=4,
                 update_freq=4,
                 random_seed=None):
        '''
        DQN Agent from https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf
        
        model: pytorch model class
            callable model class for the online and target networks
            the DQN agent instantiates each model
            
        model_params: dict
            dictionary of parameters used to define the model class (e.g., feature 
            space size, action space size, etc.)
            this should be the only input to instantiate the model class
            
        state_processor: function
            callable function that takes state as the input and outputs the processed state
            to use as a feature for the model
            the processed states are stored as experiences in the replay buffer
            
        n_actions: int
            the number of actions the agent can perform
            
        gamma: float, [0.0, 1.0]
            discount rate parameter
            
        epsilon: float, [0.0, 1.0]
            epsilon used to compute the epsilon-greedy policy
            
        min_epsilon: float, [0.0, 1.0]
            minimum value for epsilon over all episodes
            
        epsilon_decay: float, (0.0, 1.0]
            rate at which to decay epsilon after each episode
            1.0 corresponds to no decay
            
        loss_function: pytorch loss (usually the functional form)
            callable loss function that takes inputs, targets as positional arguments
            
        optimizer: pytorch optimizer
            callable optimizer that takes the learning rate as a parameter
            
        learning_rate: float
            learning rate for the optimizer
            
        l2_regularization: float
            hyperparameter for L2 regularization
            
        batch_size: int
            batch size parameter for training the online network
            
        replay_memory: int
            maximum size of the replay memory
            
        replay_start: int
            number of actions to take/experiences to store before beginning to train 
            the online network 
            this should be larger than the batch size to avoid the same experience
            showing up multiple times in the batch
            
        target_update_freq: int
            the frequency at which the target network is updated with the online
            network's weights
            
        action_repeat: int
            the number of times to repeat the same action
            
        update_freq: int
            the number of steps between each SGD (or other optimization) update
            
        random_seed: None or int
            random seed for the replay buffer
        '''
        self.n_actions = n_actions
        self.actions = np.arange(self.n_actions)

        self.state_processor = state_processor
        self.gamma = gamma
        self.epsilon = epsilon
        self.min_epsilon = min_epsilon
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.replay_start = replay_start
        self.target_update_freq = target_update_freq
        self.action_repeat = action_repeat
        self.update_freq = update_freq

        self.reset_current_step()

        self.replay_buffer = ReplayBuffer(memory_size=replay_memory,
                                          seed=random_seed)

        self.online_network = model(model_params)
        self.target_network = model(model_params)
        self.assign_variables()

        self.loss_function = loss_function
        self.optimizer = optimizer(self.online_network.parameters(),
                                   lr=learning_rate,
                                   weight_decay=l2_regularization)

    def assign_variables(self):
        '''
        Assigns the variables (weights and biases) of the online network to the target network
        '''
        self.target_network.load_state_dict(self.online_network.state_dict())

    def reset_current_step(self):
        '''
        Set the current_step attribute to 0
        '''
        self.current_step = 0

    def process_state(self, state):
        '''
        Process the state provided by the environment into the feature used by the 
        online and target networks
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
        '''
        processed_state = self.state_processor(state)

        return processed_state

    def add_to_replay_memory(self, state, action, reward, next_state,
                             terminal):
        '''
        Add the state, action, reward, next_state, terminal tuple to the replay buffer
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
            
        action: int, provided by the environment
            index of the action taken by the agent
            
        reward: float, provided by the environment
            reward for the given state, action, next state transition
            
        next_state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
        
        terminal: bool, usually provided by the environment
            whether or not the current episode has ended
        '''
        processed_state = self.process_state(state)
        processed_next_state = self.process_state(next_state)

        experience = (processed_state, action, reward, processed_next_state,
                      terminal)
        self.replay_buffer.add(experience)

    def action(self, state, mode='train'):
        '''
        Selects an action according to the greedy or epsilon-greedy policy
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
            
        mode: 'train' or 'test'
            selects an action according to the epsilon-greedy policy when set to 'train'
            selects an action according to the greedy policy when set to 'test'
        '''
        if (self.current_step % self.action_repeat
                == 0) or (not hasattr(self, 'previous_action')):
            if mode == 'test':
                state_policy, action = self.greedy_policy(state)
            else:
                state_policy, action = self.epsilon_greedy_policy(state)
        else:
            action = self.previous_action

        self.previous_action = action

        return action

    def greedy_policy(self, state):
        '''
        Returns the greedy policy as a discrete probability distribution and the 
        greedy action
        All actions except the greedy action have probability 0
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
        '''
        Q_s = self.estimate_q(state, process_state=True)

        action = np.argmax(Q_s)
        policy = np.zeros(self.n_actions)
        policy[action] = 1.0

        return policy, action

    def epsilon_greedy_policy(self, state):
        '''
        Returns the epsilon-greedy policy as a discrete probability distribution and
        an action randomly selected according to the probability distribution
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
        '''
        Q_s = self.estimate_q(state, process_state=True)

        policy = np.ones(self.n_actions) * self.epsilon / self.n_actions
        policy[np.argmax(
            Q_s)] = 1.0 - self.epsilon + self.epsilon / self.n_actions

        action = np.random.choice(self.actions, p=policy)

        return policy, action

    def estimate_q(self, state, process_state=True):
        '''
        Estimates the Q values for a given state and all actions from the online network
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
            
        process_state: bool
            whether to process the state before estimating Q_s
        '''
        if process_state:
            processed_state = self.process_state(state)
        else:
            processed_state = state

        with torch.no_grad():
            Q_s = self.online_network(processed_state)

        return Q_s

    def estimate_target_q(self, state, process_state=True):
        '''
        Estimates the Q values for a given state and all actions from the target network
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
            
        process_state: bool
            whether to process the state before estimating Q_s
        '''
        if process_state:
            processed_state = self.process_state(state)
        else:
            processed_state = state

        with torch.no_grad():
            Q_s = self.target_network(processed_state)

        return Q_s

    def update_target(self,
                      state,
                      action,
                      reward,
                      next_state,
                      terminal,
                      process_state=True):
        '''
        Calculates the update target for the state, action, reward, next_state, terminal tuple
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
            
        action: int, provided by the environment
            index of the action taken by the agent
            
        reward: float, provided by the environment
            reward for the given state, action, next state transition
            
        next_state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
        
        terminal: bool, usually provided by the environment
            whether or not the current episode has ended
            
        process_state: bool
            whether to process the state before estimating Q_s_next
        '''
        Q_s_next = self.estimate_target_q(next_state,
                                          process_state=process_state)
        terminal_mask = torch.tensor([not t for t in terminal],
                                     dtype=torch.float32)
        update_target = reward + self.gamma * torch.max(
            Q_s_next, dim=1)[0] * terminal_mask

        return update_target

    def update(self):
        '''
        Updates the model by taking a step from the optimizer
        This version does not include gradient clipping
        '''
        if self.current_step >= self.replay_start:
            if self.current_step % self.target_update_freq == 0:
                self.assign_variables()

            if self.current_step % self.update_freq == 0:
                experiences = self.replay_buffer.sample(self.batch_size)
                state, action, reward, next_state, terminal = zip(*experiences)

                state = torch.cat(state)
                action = torch.tensor(action, dtype=torch.int64)
                reward = torch.tensor(reward, dtype=torch.float32)
                next_state = torch.cat(next_state)

                update_target = self.update_target(state,
                                                   action,
                                                   reward,
                                                   next_state,
                                                   terminal,
                                                   process_state=False)
                Q_sa = self.online_network(state).gather(
                    1, action.unsqueeze(1)).squeeze()

                loss = self.loss_function(Q_sa, update_target)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        self.current_step += 1

    def update_epsilon(self):
        '''
        Decays epsilon by the decay rate
        '''
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)
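
For completeness, a rough sketch of how this last agent could be wired to an environment. The env, the train_nature_dqn name and the episode counts are assumptions; only the method names (action, add_to_replay_memory, update, update_epsilon) come from the class above:

def train_nature_dqn(env, agent, n_episodes=500, max_steps=10000):
    """Hypothetical loop: the agent processes raw states internally, update()
    only starts training once current_step >= replay_start, and epsilon is
    decayed once per episode via update_epsilon()."""
    scores = []
    for episode in range(n_episodes):
        state = env.reset()
        score = 0.0
        for _ in range(max_steps):
            action = agent.action(state, mode='train')
            next_state, reward, done, _ = env.step(action)
            agent.add_to_replay_memory(state, action, reward, next_state, done)
            agent.update()               # samples from the replay buffer internally
            state = next_state
            score += reward
            if done:
                break
        agent.update_epsilon()           # decay epsilon after each episode
        scores.append(score)
    return scores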