Example No. 1
 def __init__(self, n_agents, model_fn, n_quantiles,
              action_scale = 1.0,
              gamma = 0.99,
              exploration_noise_fn = None,
              batch_size = 64,
              replay_memory = 100000,
              replay_start = 100,
              tau = 1e-3,
              optimizer = optim.Adam,
              actor_learning_rate = 1e-4,
              critic_learning_rate = 1e-3,
              clip_gradients = None,
              share_weights = False,
              action_repeat = 1,
              update_freq = 1,
              random_seed = None):
     # create online and target networks for each agent
     self.n_agents = n_agents
     
     self.online_networks = [model_fn() for _ in range(self.n_agents)]
     self.target_networks = [model_fn() for _ in range(self.n_agents)]
     
     self.actor_optimizers = [optimizer(agent.actor_params, 
                                        lr = actor_learning_rate) for agent in self.online_networks]
     self.critic_optimizers = [optimizer(agent.critic_params, 
                                         lr = critic_learning_rate) for agent in self.online_networks]
     
     if exploration_noise_fn:
         self.exploration_noise = [exploration_noise_fn() for _ in range(self.n_agents)]
     else:
         self.exploration_noise = None
         
     self.n_quantiles = n_quantiles
     self.cumulative_density = torch.tensor((2 * np.arange(self.n_quantiles) + 1) / (2.0 * self.n_quantiles), 
                                            dtype = torch.float32).view(1, -1)
                                      
     # assign the online network variables to the target network
     for target_network, online_network in zip(self.target_networks, self.online_networks):
         target_network.load_state_dict(online_network.state_dict())
     
     self.replay_buffer = ReplayBuffer(memory_size = replay_memory, seed = random_seed)
     
     self.share_weights = share_weights
     self.action_scale = action_scale
     self.gamma = gamma
     self.tau = tau
     self.batch_size = batch_size
     self.clip_gradients = clip_gradients
     self.replay_start = replay_start
     self.action_repeat = action_repeat
     self.update_freq = update_freq
     self.random_seed = random_seed
     
     self.reset_current_step()
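
The cumulative_density tensor above holds the quantile midpoints tau_i = (2i + 1) / (2N), presumably used by a distributional (quantile-regression style) critic not shown in this snippet. A minimal, self-contained check of that formula in plain NumPy, independent of the class above:

import numpy as np

n_quantiles = 5
# Quantile midpoints tau_i = (2i + 1) / (2N); for N = 5 these are the centres of five equal bins
midpoints = (2 * np.arange(n_quantiles) + 1) / (2.0 * n_quantiles)
print(midpoints)  # [0.1 0.3 0.5 0.7 0.9]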
Example No. 2
    def __init__(self,
                 state_size,
                 action_size,
                 buffer_size,
                 batch_size,
                 gamma,
                 tau,
                 lr,
                 hidden_1,
                 hidden_2,
                 update_every,
                 epsilon,
                 epsilon_min,
                 eps_decay,
                 seed
                 ):
        """Initialize an Agent object.
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.update_every = update_every
        random.seed(seed)
        self.seed = seed
        self.learn_steps = 0
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.eps_decay = eps_decay

        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size, seed, hidden_1, hidden_2).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size, seed, hidden_1, hidden_2).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=lr)

        # Replay memory
        self.memory = ReplayBuffer(self.action_size, self.buffer_size, self.batch_size, self.seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
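
For context, a hypothetical instantiation of this agent; the enclosing class name (assumed here to be DQNAgent, as in Example No. 9) and all hyperparameter values below are illustrative assumptions, not taken from the snippet:

agent = DQNAgent(state_size=8, action_size=4, buffer_size=int(1e5), batch_size=64,
                 gamma=0.99, tau=1e-3, lr=5e-4, hidden_1=64, hidden_2=64,
                 update_every=4, epsilon=1.0, epsilon_min=0.01, eps_decay=0.995, seed=0)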
Example No. 3
    def __init__(self,
                 model_fn,
                 action_scale=1.0,
                 gamma=0.99,
                 exploration_noise=None,
                 batch_size=64,
                 replay_memory=100000,
                 replay_start=1000,
                 tau=1e-3,
                 optimizer=optim.Adam,
                 actor_learning_rate=1e-4,
                 critic_learning_rate=1e-3,
                 clip_gradients=None,
                 action_repeat=1,
                 update_freq=1,
                 random_seed=None):
        # create online and target networks
        self.online_network = model_fn()
        self.target_network = model_fn()

        # create the optimizers for the online_network
        self.actor_optimizer = optimizer(self.online_network.actor_params,
                                         lr=actor_learning_rate)
        self.critic_optimizer = optimizer(self.online_network.critic_params,
                                          lr=critic_learning_rate)
        self.clip_gradients = clip_gradients

        # assign the online network variables to the target network
        self.target_network.load_state_dict(self.online_network.state_dict())

        self.replay_buffer = ReplayBuffer(memory_size=replay_memory,
                                          seed=random_seed)
        self.exploration_noise = exploration_noise

        self.action_scale = action_scale
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.replay_start = replay_start
        self.action_repeat = action_repeat
        self.update_freq = update_freq
        self.random_seed = random_seed

        self.reset_current_step()
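
The constructor above only assumes that model_fn() returns a module exposing actor_params, critic_params, a forward pass mapping states to actions, and a critic_value(state, action) method (Example No. 8 shows how they are used). A minimal sketch of such a model, with made-up layer sizes:

import torch
import torch.nn as nn

class ActorCriticModel(nn.Module):
    # Illustrative actor-critic module compatible with the DDPG agent above; sizes are assumptions
    def __init__(self, state_size=24, action_size=2, hidden=128):
        super().__init__()
        self.actor = nn.Sequential(nn.Linear(state_size, hidden), nn.ReLU(),
                                   nn.Linear(hidden, action_size), nn.Tanh())
        self.critic = nn.Sequential(nn.Linear(state_size + action_size, hidden), nn.ReLU(),
                                    nn.Linear(hidden, 1))
        self.actor_params = list(self.actor.parameters())
        self.critic_params = list(self.critic.parameters())

    def forward(self, state):
        return self.actor(state)

    def critic_value(self, state, action):
        return self.critic(torch.cat([state, action], dim=-1))

model_fn = lambda: ActorCriticModel()   # what would be passed to the agent constructor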
Example No. 4
    def __init__(
        self,
        buffer_size,
        seed,
        state_size,
        action_size,
        hidden_layers,
        epsilon,
        epsilon_decay,
        epsilon_min,
        gamma,
        tau,
        learning_rate,
        update_frequency,
        double_Q=False,
        prioritised_replay_buffer=False,
        alpha=None,
        beta=None,
        beta_increment_size=None,
        base_priority=None,
        max_priority=None,
        training_scores=None,
        step_number=0,
    ):
        """DQNAgent initialisation function.

        Args:

            buffer_size (int): maximum size of the replay buffer.
            seed (int): random seed used for batch selection.

            state_size (int): dimension of state space for input to Q network.
            action_size (int): dimension of action space for value predictions.
            hidden_layers (list[int]): list of dimensions for the hidden layers required.

            epsilon (float): probability of choosing a non-greedy action in the policy.
            epsilon_decay (float): multiplicative decay factor applied to epsilon after each step.
            epsilon_min (float): a floor for the decay of epsilon.

            gamma (float): discount factor for future expected returns.
            tau (float): soft update factor defining how far to shift the
                       target network parameters towards the current network parameters.

            learning_rate (float): learning rate for gradient descent optimisation.
            update_frequency (int): how often to update target Q network parameters.

            double_Q (bool): set true to train using double deep Q learning.

            prioritised_replay_buffer (bool): set true to use a prioritised replay buffer.
            alpha (float): priority scaling hyperparameter.
            beta (float): initial importance sampling scaling hyperparameter.
            beta_increment_size (float): beta annealing rate.
            base_priority (float): base priority to ensure non-zero sampling probability.
            max_priority (float): initial maximum priority.

            training_scores (list[int]): rewards gained in previous training episodes (primarily
                                used when reloading saved agents).
            step_number (int): number of steps the agent has taken (primarily
                                used when reloading saved agents).

        Notes: Setting tau = 1 recovers classic DQN with a full target update.
               When using soft updates it is recommended to update the target network frequently.
        """

        self.buffer_size = buffer_size
        self.seed = seed
        if prioritised_replay_buffer:
            self.replay_buffer = PrioritisedReplayBuffer(
                buffer_size,
                alpha,
                beta,
                beta_increment_size,
                base_priority,
                max_priority,
                seed,
            )
        else:
            self.replay_buffer = ReplayBuffer(buffer_size, seed)

        self.state_size = state_size
        self.action_size = action_size
        self.hidden_layers = hidden_layers

        self.Q_net = QNetwork(state_size, action_size, hidden_layers).to(device)
        self.target_Q = QNetwork(state_size, action_size, hidden_layers).to(device)
        self.optimizer = optim.Adam(self.Q_net.parameters(), lr=learning_rate)

        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        self.gamma = gamma
        self.tau = tau

        self.learning_rate = learning_rate
        self.update_frequency = update_frequency

        self.double_Q = double_Q

        self.prioritised_replay_buffer = prioritised_replay_buffer
        self.alpha = alpha
        self.beta = beta
        self.beta_increment_size = beta_increment_size
        self.base_priority = base_priority
        self.max_priority = max_priority

        self.step_number = step_number
        if training_scores is None:
            self.training_scores = []
        else:
            self.training_scores = training_scores
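
The constructor above (and the full class in the next example) relies on a PrioritisedReplayBuffer that is not included in these snippets. Judging from how it is constructed and how update_Q samples from it, a compatible implementation could look roughly like the sketch below; this is an illustration of proportional prioritised replay, not the author's buffer:

import numpy as np

class SimplePrioritisedReplayBuffer:
    # Minimal proportional PER matching the constructor/sample/update interface used by DQNAgent
    def __init__(self, buffer_size, alpha, beta, beta_increment_size,
                 base_priority, max_priority, seed):
        self.buffer_size = buffer_size
        self.alpha, self.beta = alpha, beta
        self.beta_increment_size = beta_increment_size
        self.base_priority = base_priority
        self.max_priority = max_priority
        self.rng = np.random.default_rng(seed)
        self.data, self.priorities = [], []

    def add(self, state, action, reward, next_state, done):
        if len(self.data) >= self.buffer_size:
            self.data.pop(0)
            self.priorities.pop(0)
        self.data.append((state, action, reward, next_state, done))
        # New transitions get the current maximum priority so they are sampled at least once
        self.priorities.append(max(self.priorities, default=self.max_priority))

    def sample(self, batch_size):
        p = np.asarray(self.priorities) ** self.alpha
        p /= p.sum()
        idx = self.rng.choice(len(self.data), batch_size, p=p)
        # Importance-sampling weights, normalised by their maximum
        weights = (len(self.data) * p[idx]) ** (-self.beta)
        weights /= weights.max()
        self.beta = min(1.0, self.beta + self.beta_increment_size)
        states, actions, rewards, next_states, dones = map(np.asarray, zip(*[self.data[i] for i in idx]))
        return states, actions, rewards, next_states, dones, idx, weights

    def update(self, indices, priorities):
        for i, prio in zip(indices, priorities):
            self.priorities[i] = prio + self.base_priority

    def __len__(self):
        return len(self.data)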
Example No. 5
class DQNAgent():
    """A deep Q network agent.

    An agent using a deep Q network with a replay buffer, soft target updates
    and multiplicative epsilon decay that can learn to solve a task by interacting with
    its environment. Includes options to train with double deep Q learning
    and a prioritised replay buffer.
    """

    def __init__(
        self,
        buffer_size,
        seed,
        state_size,
        action_size,
        hidden_layers,
        epsilon,
        epsilon_decay,
        epsilon_min,
        gamma,
        tau,
        learning_rate,
        update_frequency,
        double_Q=False,
        prioritised_replay_buffer=False,
        alpha=None,
        beta=None,
        beta_increment_size=None,
        base_priority=None,
        max_priority=None,
        training_scores=None,
        step_number=0,
    ):
        """DQNAgent initialisation function.

        Args:

            buffer_size (int): maximum size of the replay buffer.
            seed (int): random seed used for batch selection.

            state_size (int): dimension of state space for input to Q network.
            action_size (int): dimension of action space for value predictions.
            hidden_layers (list[int]): list of dimensions for the hidden layers required.

            epsilon (float): probability of choosing a non-greedy action in the policy.
            epsilon_decay (float): multiplicative decay factor applied to epsilon after each step.
            epsilon_min (float): a floor for the decay of epsilon.

            gamma (float): discount factor for future expected returns.
            tau (float): soft update factor defining how far to shift the
                       target network parameters towards the current network parameters.

            learning_rate (float): learning rate for gradient descent optimisation.
            update_frequency (int): how often to update target Q network parameters.

            double_Q (bool): set true to train using double deep Q learning.

            prioritised_replay_buffer (bool): set true to use a prioritised replay buffer.
            alpha (float): priority scaling hyperparameter.
            beta (float): initial importance sampling scaling hyperparameter.
            beta_increment_size (float): beta annealing rate.
            base_priority (float): base priority to ensure non-zero sampling probability.
            max_priority (float): initial maximum priority.

            training_scores (list[int]): rewards gained in previous training episodes (primarily
                                used when reloading saved agents).
            step_number (int): number of steps the agent has taken (primarily
                                used when reloading saved agents).

        Notes: Setting tau = 1 recovers classic DQN with a full target update.
               When using soft updates it is recommended to update the target network frequently.
        """

        self.buffer_size = buffer_size
        self.seed = seed
        if prioritised_replay_buffer:
            self.replay_buffer = PrioritisedReplayBuffer(
                buffer_size,
                alpha,
                beta,
                beta_increment_size,
                base_priority,
                max_priority,
                seed,
            )
        else:
            self.replay_buffer = ReplayBuffer(buffer_size, seed)

        self.state_size = state_size
        self.action_size = action_size
        self.hidden_layers = hidden_layers

        self.Q_net = QNetwork(state_size, action_size, hidden_layers).to(device)
        self.target_Q = QNetwork(state_size, action_size, hidden_layers).to(device)
        self.optimizer = optim.Adam(self.Q_net.parameters(), lr=learning_rate)

        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        self.gamma = gamma
        self.tau = tau

        self.learning_rate = learning_rate
        self.update_frequency = update_frequency

        self.double_Q = double_Q

        self.prioritised_replay_buffer = prioritised_replay_buffer
        self.alpha = alpha
        self.beta = beta
        self.beta_increment_size = beta_increment_size
        self.base_priority = base_priority
        self.max_priority = max_priority

        self.step_number = step_number
        if training_scores is None:
            self.training_scores = []
        else:
            self.training_scores = training_scores

    def step(self, state, action, reward, next_state, done, batch_size):
        """
        A function that records experiences into the replay buffer after each
        environment step, then update the current network parameter and soft
        updates target network parameters.
        """
        self.replay_buffer.add(state, action, reward, next_state, done)
        self.update_Q(batch_size)
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

        self.step_number += 1
        if self.step_number % self.update_frequency == 0:
            self.soft_update_target_Q()

    def act_epsilon_greedy(self, state, greedy=False):
        """ Returns an epsilon greedy action """
        if greedy or random.random() > self.epsilon:
            state = torch.from_numpy(state).float().unsqueeze(0).to(device)
            self.Q_net.eval()
            with torch.no_grad():
                action_values = self.Q_net(state)
            self.Q_net.train()
            return torch.argmax(action_values).cpu().item()

        return np.random.randint(self.action_size)

    def update_Q(self, batch_size):
        """
        Updates the parameters of the current Q network using backpropagation
        and experiences from the replay buffer.
        """

        if len(self.replay_buffer) > 2*batch_size:

            experience = self.replay_buffer.sample(batch_size)

            states = torch.FloatTensor(experience[0]).to(device)
            actions = torch.LongTensor(experience[1]).unsqueeze(1).to(device)
            rewards = torch.FloatTensor(experience[2]).unsqueeze(1).to(device)
            next_states = torch.FloatTensor(experience[3]).to(device)
            done_tensor = torch.FloatTensor(experience[4]).unsqueeze(1).to(device)

            with torch.no_grad():
                target_Q_next_all = self.target_Q(next_states)

            if self.double_Q:
                # Double DQN: select the greedy action with the online network,
                # then evaluate it with the target network.
                online_actions = self.Q_net(next_states).detach().argmax(1, keepdim=True)
                Q_target_next = target_Q_next_all.gather(1, online_actions)
            else:
                Q_target_next = target_Q_next_all.max(1, keepdim=True)[0]

            Q_expected = self.Q_net(states).gather(1, actions)
            Q_target = rewards + self.gamma * Q_target_next * (1 - done_tensor)

            if self.prioritised_replay_buffer:
                idx_list = experience[5]
                weights = torch.FloatTensor(experience[6]).unsqueeze(1).to(device)
                td_error = (Q_target - Q_expected)
                priority_list = torch.abs(td_error.squeeze().detach()).cpu().numpy()
                self.replay_buffer.update(idx_list, priority_list)
                loss = torch.mean((weights*td_error)**2)
            else:
                loss = F.mse_loss(Q_expected, Q_target)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def soft_update_target_Q(self):
        """ Soft updates the target Q network """
        for target_Q_param, Q_param in zip(self.target_Q.parameters(), self.Q_net.parameters()):
            target_Q_param.data = self.tau * Q_param.data + (1 - self.tau) * target_Q_param.data

    def save_agent(self, name, path=""):
        """ Saves agent parameters for loading using load_agent function
        Note: it is torch convention to save models with .pth extension
        """
        params = (
            self.buffer_size,
            self.seed,
            self.state_size,
            self.action_size,
            self.hidden_layers,
            self.epsilon,
            self.epsilon_decay,
            self.epsilon_min,
            self.gamma,
            self.tau,
            self.learning_rate,
            self.update_frequency,
            self.double_Q,
            self.prioritised_replay_buffer,
            self.alpha,
            # a plain ReplayBuffer has no beta attribute, so fall back to the stored value
            self.replay_buffer.beta if self.prioritised_replay_buffer else self.beta,
            self.beta_increment_size,
            self.base_priority,
            self.max_priority,
            self.training_scores,
            self.step_number,
        )

        checkpoint = {
            "params": params,
            "state_dict": self.Q_net.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }

        path += name
        torch.save(checkpoint, path)
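
save_agent stores the constructor arguments together with the network and optimizer state. The load_agent function referenced in its docstring is not part of the snippet; a plausible sketch, assuming it mirrors the checkpoint layout above:

def load_agent(path):
    # Hypothetical counterpart to save_agent; rebuilds the agent from the saved constructor params
    checkpoint = torch.load(path)
    agent = DQNAgent(*checkpoint["params"])
    # The checkpoint only stores the online network, so the target starts as a copy of it
    agent.Q_net.load_state_dict(checkpoint["state_dict"])
    agent.target_Q.load_state_dict(checkpoint["state_dict"])
    agent.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return agent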
Example No. 6
def td3(env_fn,
        actor_critic=core.MLPActorCritic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        epochs=100,
        replay_size=int(1e6),
        gamma=0.99,
        polyak=0.995,
        pi_lr=1e-3,
        q_lr=1e-3,
        batch_size=100,
        start_steps=10000,
        update_after=1000,
        update_every=50,
        act_noise=0.1,
        target_noise=0.2,
        noise_clip=0.5,
        policy_delay=2,
        num_test_episodes=100,
        max_ep_len=1000,
        logger_kwargs=dict(),
        save_freq=1,
        multistep_n=1,
        use_parameter_noise=False,
        decay_exploration=False,
        save_k_latest=1):
    """
    Twin Delayed Deep Deterministic Policy Gradient (TD3)


    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The constructor method for a PyTorch Module with an ``act`` 
            method, a ``pi`` module, a ``q1`` module, and a ``q2`` module.
            The ``act`` method and ``pi`` module should accept batches of 
            observations as inputs, and ``q1`` and ``q2`` should accept a batch 
            of observations and a batch of actions as inputs. When called, 
            these should return:

            ===========  ================  ======================================
            Call         Output Shape      Description
            ===========  ================  ======================================
            ``act``      (batch, act_dim)  | Numpy array of actions for each 
                                           | observation.
            ``pi``       (batch, act_dim)  | Tensor containing actions from policy
                                           | given observations.
            ``q1``       (batch,)          | Tensor containing one current estimate
                                           | of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ``q2``       (batch,)          | Tensor containing the other current 
                                           | estimate of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object 
            you provided to TD3.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target 
            networks. Target networks are updated towards main networks 
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow 
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually 
            close to 1.)

        pi_lr (float): Learning rate for policy.

        q_lr (float): Learning rate for Q-networks.

        batch_size (int): Minibatch size for SGD.

        start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration.

        update_after (int): Number of env interactions to collect before
            starting to do gradient descent updates. Ensures replay buffer
            is full enough for useful updates.

        update_every (int): Number of env interactions that should elapse
            between gradient descent updates. Note: Regardless of how long 
            you wait between updates, the ratio of env steps to gradient steps 
            is locked to 1.

        act_noise (float): Stddev for Gaussian exploration noise added to 
            policy at training time. (At test time, no noise is added.)

        target_noise (float): Stddev for smoothing noise added to target 
            policy.

        noise_clip (float): Limit for absolute value of target policy 
            smoothing noise.

        policy_delay (int): Policy will only be updated once every 
            policy_delay times for each update of the Q-networks.

        num_test_episodes (int): Number of episodes to test the deterministic
            policy at the end of each epoch.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

        multistep_n (int): n used for n-step returns; values greater than 1
            switch to the multi-step replay buffer.

        use_parameter_noise (bool): if True, explore with a parameter-space
            perturbed copy of the actor instead of additive action noise.

        decay_exploration (bool): if True, decay the exploration noise scale
            at the end of each epoch.

        save_k_latest (int): number of recent checkpoints to rotate through
            when saving the latest model.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    torch.manual_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape[0]

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    act_limit = env.action_space.high[0]

    sigma = act_noise

    # Create actor-critic module and target networks
    ac = actor_critic(env.observation_space, env.action_space, **ac_kwargs)
    ac_targ = deepcopy(ac)

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False

    # List of parameters for both Q-networks (save this for convenience)
    q_params = itertools.chain(ac.q1.parameters(), ac.q2.parameters())

    # Experience buffer
    if multistep_n > 1:
        replay_buffer = MultiStepReplayBuffer(obs_dim=obs_dim,
                                              act_dim=act_dim,
                                              size=replay_size,
                                              n=multistep_n,
                                              gamma=gamma)
    else:
        replay_buffer = ReplayBuffer(obs_dim=obs_dim,
                                     act_dim=act_dim,
                                     size=replay_size)

    # Count variables (protip: try to get a feel for how different size networks behave!)
    var_counts = tuple(
        core.count_vars(module) for module in [ac.pi, ac.q1, ac.q2])
    logger.log('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' %
               var_counts)

    # Set up function for computing TD3 Q-losses
    def compute_loss_q(data):
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data[
            'obs2'], data['done']

        q1 = ac.q1(o, a)
        q2 = ac.q2(o, a)

        # Bellman backup for Q functions
        with torch.no_grad():
            pi_targ = ac_targ.pi(o2)

            # Target policy smoothing
            epsilon = torch.randn_like(pi_targ) * target_noise
            epsilon = torch.clamp(epsilon, -noise_clip, noise_clip)
            a2 = pi_targ + epsilon
            a2 = torch.clamp(a2, -act_limit, act_limit)

            # Target Q-values
            q1_pi_targ = ac_targ.q1(o2, a2)
            q2_pi_targ = ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma**(multistep_n) * (1 - d) * q_pi_targ

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup)**2).mean()
        loss_q2 = ((q2 - backup)**2).mean()
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        loss_info = dict(Q1Vals=q1.detach().numpy(),
                         Q2Vals=q2.detach().numpy())

        return loss_q, loss_info

    # Set up function for computing TD3 pi loss
    def compute_loss_pi(data):
        o = data['obs']
        q1_pi = ac.q1(o, ac.pi(o))
        return -q1_pi.mean()

    # Set up optimizers for policy and q-function
    pi_optimizer = Adam(ac.pi.parameters(), lr=pi_lr)
    q_optimizer = Adam(q_params, lr=q_lr)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    def update(data, timer):
        # First run one gradient descent step for Q1 and Q2
        q_optimizer.zero_grad()
        loss_q, loss_info = compute_loss_q(data)
        loss_q.backward()
        q_optimizer.step()

        # Record things
        logger.store(LossQ=loss_q.item(), **loss_info)

        # Possibly update pi and target networks
        if timer % policy_delay == 0:

            # Freeze Q-networks so you don't waste computational effort
            # computing gradients for them during the policy learning step.
            for p in q_params:
                p.requires_grad = False

            # Next run one gradient descent step for pi.
            pi_optimizer.zero_grad()
            loss_pi = compute_loss_pi(data)
            loss_pi.backward()
            pi_optimizer.step()

            # Unfreeze Q-networks so you can optimize it at next DDPG step.
            for p in q_params:
                p.requires_grad = True

            # Record things
            logger.store(LossPi=loss_pi.item())

            # Finally, update target networks by polyak averaging.
            with torch.no_grad():
                for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                    # NB: We use an in-place operations "mul_", "add_" to update target
                    # params, as opposed to "mul" and "add", which would make new tensors.
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)

    def get_action(o, noise_scale):
        a = ac.act(torch.as_tensor(o, dtype=torch.float32))
        noise_dim = None if act_dim == 1 else act_dim  # Fixes a bug that occurs in case of 1-dimensional action spaces
        a += noise_scale * np.random.standard_normal(noise_dim)
        return np.clip(a, -act_limit, act_limit)

    def get_action_from_perturbed_model(o, actor):
        a = actor.act(torch.as_tensor(o, dtype=torch.float32))
        return np.clip(a, -act_limit, act_limit)

    def test_agent():
        test_returns = []
        for j in range(num_test_episodes):
            # Curriculum environments need a specific reset configuration
            if type(test_env).__name__ == 'CurriculumEnv':
                o, d, ep_ret, ep_len = test_env.reset(opponent="strong",
                                                      mode=0), False, 0, 0
            else:
                o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0

            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time (noise_scale=0)
                o, r, d, _ = test_env.step(get_action(o, 0))
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
            test_returns.append(ep_ret)
        # Average over all test episodes so model selection is not based on a single episode
        return np.mean(test_returns)

    def get_perturbed_model(sigma):
        actor = deepcopy(ac)
        with torch.no_grad():
            for param in actor.pi.parameters():
                param.add_(torch.randn(param.size()) * sigma)
        return actor

    def update_sigma(target_noise,
                     sigma,
                     actor,
                     perturbed_actor,
                     batch_size=128):
        obs = replay_buffer.sample_batch(batch_size=batch_size)['obs']
        with torch.no_grad():
            ac1 = actor.pi(obs)
            ac2 = perturbed_actor.pi(obs)
            dist = torch.sqrt(torch.mean((ac1 - ac2)**2))
        if dist < target_noise:
            sigma *= 1.01
        else:
            sigma /= 1.01
        return sigma

    # Prepare for interaction with environment
    best_test_ep_return = float('-inf')
    total_steps = steps_per_epoch * epochs
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0
    save_id = 0

    if use_parameter_noise:
        perturbed_actor = get_perturbed_model(sigma)

    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):

        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy (with some noise, via act_noise).
        if t > start_steps:
            if use_parameter_noise:
                a = get_action_from_perturbed_model(o.squeeze(),
                                                    perturbed_actor)
            else:
                a = get_action(o.squeeze(), act_noise)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, r, d, info = env.step(a)
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len):
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            o, ep_ret, ep_len = env.reset(), 0, 0

        # Update handling
        if t >= update_after and t % update_every == 0:
            for j in range(update_every):
                batch = replay_buffer.sample_batch(batch_size)
                update(data=batch, timer=j)
                if use_parameter_noise:
                    perturbed_actor = get_perturbed_model(
                        sigma
                    )  # To make sure that perturbed actor is based on the same actor that we are testing against
                    sigma = update_sigma(act_noise, sigma, ac, perturbed_actor)
                    perturbed_actor = get_perturbed_model(sigma)

        # End of epoch handling
        if (t + 1) % steps_per_epoch == 0:

            # Decay the action noise or sigma (we are already at an epoch boundary here)
            if decay_exploration:
                sigma /= 1.001
                act_noise /= 1.001

            epoch = (t + 1) // steps_per_epoch

            # Test the performance of the deterministic version of the agent.
            test_ep_return = test_agent()

            # Save best model:
            if test_ep_return > best_test_ep_return:
                best_test_ep_return = test_ep_return
                logger.save_state({'env': env}, 0)

            # Save latest model
            # if (epoch % save_freq == 0) or (epoch == epochs):
            save_id += 1
            logger.save_state({'env': env}, save_id % save_k_latest)

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('Time', time.time() - start_time)
            if use_parameter_noise:
                logger.log_tabular('sigma', sigma)
            if 'stage' in info.keys():
                logger.log_tabular('CStage', info['stage'])
            logger.dump_tabular()
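
The Bellman backup in compute_loss_q discounts the bootstrap term by gamma**multistep_n, which assumes the MultiStepReplayBuffer (not shown) stores n-step discounted reward sums. A small sketch of that aggregation, under that assumption:

def n_step_reward(rewards, gamma, n):
    # Discounted sum of the first n rewards: r_0 + gamma*r_1 + ... + gamma^(n-1)*r_{n-1}
    return sum(gamma ** k * r for k, r in enumerate(rewards[:n]))

# With n = 3 and gamma = 0.99 the target then bootstraps with gamma**3 * min(Q1, Q2)(s_{t+3}, a')
print(n_step_reward([1.0, 1.0, 1.0], 0.99, 3))  # 2.9701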
Example No. 7
    return loss


# Params
num_quant = 51
render = True

# Here we've defined a schedule for exploration i.e. random action with prob eps
eps_start, eps_end, eps_dec = 0.9, 0.1, 500

env = gym.make('MountainCar-v0')
obs_dim = env.observation_space.shape[0]
action_dim = 1
action_num = env.action_space.n

memory = ReplayBuffer(10000, obs_dim, action_dim)

Z = Network(obs_dim, num_quant, action_num)
Ztgt = Network(obs_dim, num_quant, action_num)
Ztgt.model.set_weights(Z.model.get_weights())
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

tau = tf.constant(np.reshape(
    (2 * np.arange(Z.num_quant) + 1) / (2.0 * Z.num_quant), (-1, 1, 1)),
                  dtype=tf.float32)

gamma, batch_size = 0.99, 32
steps_done, running_reward = 0, 0
batch_range = tf.reshape(tf.range(0, batch_size), (-1, 1))

for episode in range(500):
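
The fragment above starts with a dangling `return loss`, suggesting a quantile-regression loss was defined just before the excerpt. That function is not shown, so the following is only a self-contained sketch of a QR-DQN style quantile Huber loss in TensorFlow, assuming (batch, num_quant) shapes for both predicted and target quantiles:

import tensorflow as tf

def quantile_huber_loss(theta, theta_target, tau, kappa=1.0):
    # theta, theta_target: (batch, N) predicted and target quantiles; tau: the (N, 1, 1) midpoints above
    # Pairwise TD errors u[b, i, j] = target_j - prediction_i
    u = theta_target[:, tf.newaxis, :] - theta[:, :, tf.newaxis]
    huber = tf.where(tf.abs(u) <= kappa,
                     0.5 * tf.square(u),
                     kappa * (tf.abs(u) - 0.5 * kappa))
    # Quantile weighting |tau_i - 1{u < 0}|, broadcast over the target dimension j
    weight = tf.abs(tf.reshape(tau, (1, -1, 1)) - tf.cast(u < 0, tf.float32))
    return tf.reduce_mean(tf.reduce_sum(tf.reduce_mean(weight * huber, axis=2), axis=1))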
Example No. 8
class DDPGAgent():
    def __init__(self,
                 model_fn,
                 action_scale=1.0,
                 gamma=0.99,
                 exploration_noise=None,
                 batch_size=64,
                 replay_memory=100000,
                 replay_start=1000,
                 tau=1e-3,
                 optimizer=optim.Adam,
                 actor_learning_rate=1e-4,
                 critic_learning_rate=1e-3,
                 clip_gradients=None,
                 action_repeat=1,
                 update_freq=1,
                 random_seed=None):
        # create online and target networks
        self.online_network = model_fn()
        self.target_network = model_fn()

        # create the optimizers for the online_network
        self.actor_optimizer = optimizer(self.online_network.actor_params,
                                         lr=actor_learning_rate)
        self.critic_optimizer = optimizer(self.online_network.critic_params,
                                          lr=critic_learning_rate)
        self.clip_gradients = clip_gradients

        # assign the online network variables to the target network
        self.target_network.load_state_dict(self.online_network.state_dict())

        self.replay_buffer = ReplayBuffer(memory_size=replay_memory,
                                          seed=random_seed)
        self.exploration_noise = exploration_noise

        self.action_scale = action_scale
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.replay_start = replay_start
        self.action_repeat = action_repeat
        self.update_freq = update_freq
        self.random_seed = random_seed

        self.reset_current_step()

    def reset_current_step(self):
        self.current_step = 0

    def soft_update(self):
        for target_param, online_param in zip(
                self.target_network.parameters(),
                self.online_network.parameters()):
            target_param.detach_()
            target_param.copy_(target_param * (1.0 - self.tau) +
                               online_param * self.tau)

    def add_to_replay_memory(self, state, action, reward, next_state,
                             terminal):
        experience = (state, action, reward, next_state, terminal)
        self.replay_buffer.add(experience)

    def action(self, state):
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)

        if (self.current_step % self.action_repeat
                == 0) or (not hasattr(self, '_previous_action')):
            action = self.online_network(state)
            action = action.squeeze().detach().numpy()

            if self.exploration_noise is not None:
                action = action + self.exploration_noise.sample()
                action = np.clip(action, -self.action_scale, self.action_scale)
        else:
            action = self._previous_action

        self._previous_action = action

        return action

    def update_target(self, state, action, reward, next_state, terminal):
        next_action = self.target_network(next_state).detach()
        Q_sa_next = self.target_network.critic_value(next_state,
                                                     next_action).detach()

        update_target = reward.unsqueeze(
            -1) + self.gamma * Q_sa_next * (1 - terminal).unsqueeze(-1)
        # Q_sa_next is already detached; keep the whole target out of the autograd graph
        update_target = update_target.detach()

        return update_target

    def update(self, state, action, reward, next_state, terminal):
        self.add_to_replay_memory(state, action, reward, next_state, terminal)

        if terminal and (self.exploration_noise is not None):
            # Reset the noise process at the end of an episode (if it supports reset_states)
            try:
                self.exploration_noise.reset_states()
            except AttributeError:
                pass

        if self.current_step >= self.replay_start:
            if self.current_step % self.update_freq == 0:
                experiences = self.replay_buffer.sample(self.batch_size)
                state, action, reward, next_state, terminal = zip(*experiences)

                state = torch.tensor(state, dtype=torch.float32)
                action = torch.tensor(action, dtype=torch.float32)
                reward = torch.tensor(reward, dtype=torch.float32)
                next_state = torch.tensor(next_state, dtype=torch.float32)
                terminal = torch.tensor(terminal, dtype=torch.float32)

                update_target = self.update_target(state, action, reward,
                                                   next_state, terminal)
                Q_sa = self.online_network.critic_value(state, action)
                critic_loss = (Q_sa -
                               update_target).pow(2).mul(0.5).sum(-1).mean()

                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                if self.clip_gradients:
                    nn.utils.clip_grad_norm_(self.online_network.critic_params,
                                             self.clip_gradients)
                self.critic_optimizer.step()

                action = self.online_network(state)
                policy_loss = -self.online_network.critic_value(state,
                                                                action).mean()

                self.actor_optimizer.zero_grad()
                policy_loss.backward()
                if self.clip_gradients:
                    nn.utils.clip_grad_norm_(self.online_network.actor_params,
                                             self.clip_gradients)
                self.actor_optimizer.step()

                self.soft_update()

        self.current_step += 1
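
The agent above only assumes that its exploration_noise object exposes sample() and, optionally, reset_states(). A minimal Ornstein-Uhlenbeck process with that interface, as an illustration rather than the noise class actually used here:

import numpy as np

class OUNoise:
    # Ornstein-Uhlenbeck process exposing the sample()/reset_states() interface assumed above
    def __init__(self, size, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        self.reset_states()

    def reset_states(self):
        self.state = self.mu.copy()

    def sample(self):
        dx = self.theta * (self.mu - self.state) + self.sigma * np.random.standard_normal(len(self.state))
        self.state = self.state + dx
        return self.state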
Example No. 9
class DQNAgent():
    """Interacts with and learns from the environment."""

    def __init__(self,
                 state_size,
                 action_size,
                 buffer_size,
                 batch_size,
                 gamma,
                 tau,
                 lr,
                 hidden_1,
                 hidden_2,
                 update_every,
                 epsilon,
                 epsilon_min,
                 eps_decay,
                 seed
                 ):
        """Initialize an Agent object.
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.update_every = update_every
        random.seed(seed)
        self.seed = seed
        self.learn_steps = 0
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.eps_decay = eps_decay

        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size, seed, hidden_1, hidden_2).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size, seed, hidden_1, hidden_2).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=lr)

        # Replay memory
        self.memory = ReplayBuffer(self.action_size, self.buffer_size, self.batch_size, self.seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0


    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:
            # Sample if enough samples are available
            if len(self.memory) > self.batch_size:
                experiences = self.memory.sample()
                self.learn(experiences)

    def act(self, state):
        """Returns actions for given state as per current policy.
        Params
        ======
            state (array_like): current state
        """
        self.epsilon = max(self.epsilon*self.eps_decay, self.epsilon_min)

        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() > self.epsilon:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Variable]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
        # Compute Q targets for current states
        Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones))

        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.learn_steps += 1

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target)

    def soft_update(self, local_model, target_model):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)
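
A hypothetical interaction loop for this agent; `env` (a Gym-style environment with the old 4-tuple step API), the already constructed `agent`, and the episode count are assumptions rather than part of the snippet:

for episode in range(500):
    state = env.reset()
    done = False
    while not done:
        action = agent.act(state)                            # epsilon-greedy action
        next_state, reward, done, _ = env.step(action)
        agent.step(state, action, reward, next_state, done)  # store the transition and maybe learn
        state = next_state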
Example No. 10
class MADDPGAgent():
    def __init__(self, n_agents, model_fn,
                 action_scale = 1.0,
                 gamma = 0.99,
                 exploration_noise_fn = None,
                 batch_size = 64,
                 replay_memory = 100000,
                 replay_start = 100,
                 tau = 1e-3,
                 optimizer = optim.Adam,
                 actor_learning_rate = 1e-4,
                 critic_learning_rate = 1e-3,
                 clip_gradients = None,
                 share_weights = False,
                 action_repeat = 1,
                 update_freq = 1,
                 random_seed = None):
        # create online and target networks for each agent
        self.n_agents = n_agents
        
        self.online_networks = [model_fn() for _ in range(self.n_agents)]
        self.target_networks = [model_fn() for _ in range(self.n_agents)]
        
        self.actor_optimizers = [optimizer(agent.actor_params, 
                                           lr = actor_learning_rate) for agent in self.online_networks]
        self.critic_optimizers = [optimizer(agent.critic_params, 
                                            lr = critic_learning_rate) for agent in self.online_networks]
        
        if exploration_noise_fn:
            self.exploration_noise = [exploration_noise_fn() for _ in range(self.n_agents)]
        else:
            self.exploration_noise = None
                                         
        # assign the online network variables to the target network
        for target_network, online_network in zip(self.target_networks, self.online_networks):
            target_network.load_state_dict(online_network.state_dict())
        
        self.replay_buffer = ReplayBuffer(memory_size = replay_memory, seed = random_seed)
        
        self.share_weights = share_weights
        self.action_scale = action_scale
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.clip_gradients = clip_gradients
        self.replay_start = replay_start
        self.action_repeat = action_repeat
        self.update_freq = update_freq
        self.random_seed = random_seed
        
        self.reset_current_step()
        
    def reset_current_step(self):
        self.current_step = 0
        
    def soft_update(self):
        for target_network, online_network in zip(self.target_networks, self.online_networks):
            for target_param, online_param in zip(target_network.parameters(), online_network.parameters()):
                target_param.detach_()
                target_param.copy_(target_param * (1.0 - self.tau) + online_param * self.tau)
                
    def assign_weights(self):
        for target_network, online_network in zip(self.target_networks, self.online_networks):
            target_network.load_state_dict(self.target_networks[0].state_dict())
            online_network.load_state_dict(self.online_networks[0].state_dict())
    
    def add_to_replay_memory(self, state, action, reward, next_state, terminal):
        experience = (state, action, reward, next_state, terminal)
        self.replay_buffer.add(experience)
    
    def action(self, state):
        if (self.current_step % self.action_repeat == 0) or (not hasattr(self, '_previous_action')):
            actions = []
            for i in range(self.n_agents):
                obs = torch.tensor(state[i], dtype = torch.float32).unsqueeze(0)
                action = self.online_networks[i](obs)
                action = action.squeeze().detach().numpy()
                
                if self.exploration_noise:
                    action = action + self.exploration_noise[i].sample()
                    action = np.clip(action, -self.action_scale, self.action_scale)
                actions.append(action)
                
            action = np.asarray(actions)
        else:
            action = self._previous_action
                
        self._previous_action = action

        return action
    
    def update_target(self, state, action, reward, next_state, terminal):
        all_next_actions = []
        for i in range(self.n_agents):
            next_action = self.target_networks[i](next_state[:, i, :])
            all_next_actions.append(next_action)

        all_next_actions = torch.cat(all_next_actions, dim = 1)
        all_next_states = next_state.view(-1, next_state.shape[1] * next_state.shape[2])

        Q_sa_next = self.target_networks[self.current_agent].critic_value(all_next_states, all_next_actions)

        reward = reward[:, self.current_agent].unsqueeze(-1)
        terminal = terminal[:, self.current_agent].unsqueeze(-1)
 
        update_target = reward + self.gamma * Q_sa_next * (1 - terminal)
        update_target = update_target.detach()
        
        return update_target
    
    def update(self, state, action, reward, next_state, terminal):
        self.add_to_replay_memory(state, action, reward, next_state, terminal)
        
        if np.any(terminal) and (self.exploration_noise is not None):
            for i in range(self.n_agents):
                try:
                    self.exploration_noise[i].reset_states()
                except AttributeError:
                    pass
        
        if self.current_step >= self.replay_start:
            if self.current_step % self.update_freq == 0:
                if self.share_weights:
                    update_agents = 1
                else:
                    update_agents = self.n_agents

                for i in range(update_agents):
                    self.current_agent = i
                
                    experiences = self.replay_buffer.sample(self.batch_size)     
                    state, action, reward, next_state, terminal = zip(*experiences)
                    
                    state = torch.tensor(state, dtype = torch.float32)
                    action = torch.tensor(action, dtype = torch.float32)
                    reward = torch.tensor(reward, dtype = torch.float32)
                    next_state = torch.tensor(next_state, dtype = torch.float32)
                    terminal = torch.tensor(terminal, dtype = torch.float32)

                    all_actions = action.view(-1, action.shape[1] * action.shape[2])
                    all_states = state.view(-1, state.shape[1] * state.shape[2])
                    
                    update_target = self.update_target(state, action, reward, next_state, terminal)
                    
                    Q_sa = self.online_networks[self.current_agent].critic_value(all_states, all_actions)
                    critic_loss = F.mse_loss(Q_sa, update_target)
                    
                    self.critic_optimizers[self.current_agent].zero_grad()
                    critic_loss.backward()
                    if self.clip_gradients:
                        nn.utils.clip_grad_norm_(self.online_networks[self.current_agent].critic_params, self.clip_gradients)
                    self.critic_optimizers[self.current_agent].step()

                    agent_action = self.online_networks[self.current_agent](state[:, self.current_agent, :])

                    predicted_actions = action.clone().detach()
                    predicted_actions[:, self.current_agent] = agent_action
                    predicted_actions = predicted_actions.view(-1, predicted_actions.shape[1] * predicted_actions.shape[2])

                    policy_loss = -self.online_networks[self.current_agent].critic_value(all_states, predicted_actions).mean()

                    self.actor_optimizers[self.current_agent].zero_grad()
                    policy_loss.backward()
                    if self.clip_gradients:
                        nn.utils.clip_grad_norm_(self.online_networks[self.current_agent].actor_params, self.clip_gradients)
                    self.actor_optimizers[self.current_agent].step()
                    
                self.soft_update()
                
                if self.share_weights:
                    self.assign_weights()
            
        self.current_step += 1
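
For reference, the shape bookkeeping that update_target and update above rely on, with hypothetical Tennis-like sizes (two agents, 24-dimensional observations, 2-dimensional actions); the values are assumptions for illustration:

import torch

batch, n_agents, obs_size, act_size = 64, 2, 24, 2
state = torch.zeros(batch, n_agents, obs_size)
action = torch.zeros(batch, n_agents, act_size)

# The centralised critic sees one flattened joint observation and joint action per sample
all_states = state.view(-1, n_agents * obs_size)     # -> torch.Size([64, 48])
all_actions = action.view(-1, n_agents * act_size)   # -> torch.Size([64, 4])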
Example No. 11
# copy weights
q_target.set_weights(q_function.get_weights())

valid_actions = [0, 1, 2]
centered_actions = np.fromiter(map(center_action, valid_actions), dtype=int)

action_selection_strategy = ActionSelection(
    action_bounds=(min(centered_actions), max(centered_actions)), max_steps=200
)
action_selection_strategy.reset(ou_sigma, tau=10)
action_selection_strategy.plot_noise()

if use_per:
    memory = PrioritizedReplayBuffer(buffer_size, obs_dim, action_dim)
else:
    memory = ReplayBuffer(buffer_size, obs_dim, action_dim)


normalizer = Normalize(obs_dim)

function_plotter = PlotFunction([-4, 4, -4, 4])

good_policy = np.concatenate((-np.ones(20), np.ones(40), -np.ones(60), np.ones(100))).astype(int)

plot_period = 10
steps_done = 0
episode = -1
while steps_done < max_steps:
    episode += 1
    state = env.reset()
    state = state.astype(np.float32)
Example No. 12
    def __init__(self,
                 model,
                 model_params,
                 state_processor,
                 n_actions,
                 gamma=0.99,
                 epsilon=1.0,
                 min_epsilon=1e-2,
                 epsilon_decay=.999,
                 loss_function=F.smooth_l1_loss,
                 optimizer=optim.Adam,
                 learning_rate=1e-3,
                 l2_regularization=0.0,
                 batch_size=32,
                 replay_memory=1000000,
                 replay_start=50000,
                 target_update_freq=1000,
                 action_repeat=4,
                 update_freq=4,
                 random_seed=None):
        '''
        DQN Agent from https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf
        
        model: pytorch model class
            callable model class for the online and target networks
            the DQN agent instantiates each model
            
        model_params: dict
            dictionary of parameters used to define the model class (e.g., feature 
            space size, action space size, etc.)
            this should be the only input to instantiate the model class
            
        state_processor: function
            callable function that takes state as the input and outputs the processed state
            to use as a feature for the model
            the processed states are stored as experiences in the replay buffer
            
        n_actions: int
            the number of actions the agent can perform
            
        gamma: float, [0.0, 1.0]
            discount rate parameter
            
        epsilon: float, [0.0, 1.0]
            epsilon used to compute the epsilon-greedy policy
            
        min_epsilon: float, [0.0, 1.0]
            minimum value for epsilon over all episodes
            
        epsilon_decay: float, (0.0, 1.0]
            rate at which to decay epsilon after each episode
            1.0 corresponds to no decay
            
        loss_function: pytorch loss (usually the functional form)
            callable loss function that takes inputs, targets as positional arguments
            
        optimizer: pytorch optimizer
            callable optimizer that takes the learning rate as a parameter
            
        learning_rate: float
            learning rate for the optimizer
            
        l2_regularization: float
            hyperparameter for L2 regularization
            
        batch_size: int
            batch size parameter for training the online network
            
        replay_memory: int
            maximum size of the replay memory
            
        replay_start: int
            number of actions to take/experiences to store before beginning to train 
            the online network 
            this should be larger than the batch size to avoid the same experience
            showing up multiple times in the batch
            
        target_update_freq: int
            the frequency at which the target network is updated with the online
            network's weights
            
        action_repeat: int
            the number of times to repeat the same action
            
        update_freq: int
            the number of steps between each SGD (or other optimization) update
            
        random_seed: None or int
            random seed for the replay buffer
        '''
        self.n_actions = n_actions
        self.actions = np.arange(self.n_actions)

        self.state_processor = state_processor
        self.gamma = gamma
        self.epsilon = epsilon
        self.min_epsilon = min_epsilon
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.replay_start = replay_start
        self.target_update_freq = target_update_freq
        self.action_repeat = action_repeat
        self.update_freq = update_freq

        self.reset_current_step()

        self.replay_buffer = ReplayBuffer(memory_size=replay_memory,
                                          seed=random_seed)

        self.online_network = model(model_params)
        self.target_network = model(model_params)
        self.assign_variables()

        self.loss_function = loss_function
        self.optimizer = optimizer(self.online_network.parameters(),
                                   lr=learning_rate,
                                   weight_decay=l2_regularization)
Exemplo n.º 13
0
class DQNAgent():
    def __init__(self,
                 model,
                 model_params,
                 state_processor,
                 n_actions,
                 gamma=0.99,
                 epsilon=1.0,
                 min_epsilon=1e-2,
                 epsilon_decay=.999,
                 loss_function=F.smooth_l1_loss,
                 optimizer=optim.Adam,
                 learning_rate=1e-3,
                 l2_regularization=0.0,
                 batch_size=32,
                 replay_memory=1000000,
                 replay_start=50000,
                 target_update_freq=1000,
                 action_repeat=4,
                 update_freq=4,
                 random_seed=None):
        '''
        DQN Agent from https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf
        
        model: pytorch model class
            callable model class for the online and target networks
            the DQN agent instantiates each model
            
        model_params: dict
            dictionary of parameters used to define the model class (e.g., feature 
            space size, action space size, etc.)
            this should be the only input to instantiate the model class
            
        state_processor: function
            callable function that takes state as the input and outputs the processed state
            to use as a feature for the model
            the processed states are stored as experiences in the replay buffer
            
        n_actions: int
            the number of actions the agent can perform
            
        gamma: float, [0.0, 1.0]
            discount rate parameter
            
        epsilon: float, [0.0, 1.0]
            epsilon used to compute the epsilon-greedy policy
            
        min_epsilon: float, [0.0, 1.0]
            minimum value for epsilon over all episodes
            
        epsilon_decay: float, (0.0, 1.0]
            rate at which to decay epsilon after each episode
            1.0 corresponds to no decay
            
        loss_function: pytorch loss (usually the functional form)
            callable loss function that takes inputs, targets as positional arguments
            
        optimizer: pytorch optimizer
            callable optimizer that takes the learning rate as a parameter
            
        learning_rate: float
            learning rate for the optimizer
            
        l2_regularization: float
            hyperparameter for L2 regularization
            
        batch_size: int
            batch size parameter for training the online network
            
        replay_memory: int
            maximum size of the replay memory
            
        replay_start: int
            number of actions to take/experiences to store before beginning to train 
            the online network 
            this should be larger than the batch size to avoid the same experience
            showing up multiple times in the batch
            
        target_update_freq: int
            the frequency at which the target network is updated with the online
            network's weights
            
        action_repeat: int
            the number of times to repeat the same action
            
        update_freq: int
            the number of steps between each SGD (or other optimization) update
            
        random_seed: None or int
            random seed for the replay buffer
        '''
        self.n_actions = n_actions
        self.actions = np.arange(self.n_actions)

        self.state_processor = state_processor
        self.gamma = gamma
        self.epsilon = epsilon
        self.min_epsilon = min_epsilon
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.replay_start = replay_start
        self.target_update_freq = target_update_freq
        self.action_repeat = action_repeat
        self.update_freq = update_freq

        self.reset_current_step()

        self.replay_buffer = ReplayBuffer(memory_size=replay_memory,
                                          seed=random_seed)

        self.online_network = model(model_params)
        self.target_network = model(model_params)
        self.assign_variables()

        self.loss_function = loss_function
        self.optimizer = optimizer(self.online_network.parameters(),
                                   lr=learning_rate,
                                   weight_decay=l2_regularization)

    def assign_variables(self):
        '''
        Assigns the variables (weights and biases) of the online network to the target network
        '''
        self.target_network.load_state_dict(self.online_network.state_dict())

    def reset_current_step(self):
        '''
        Set the current_step attribute to 0
        '''
        self.current_step = 0

    def process_state(self, state):
        '''
        Process the state provided by the environment into the feature used by the 
        online and target networks
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
        '''
        processed_state = self.state_processor(state)

        return processed_state

    def add_to_replay_memory(self, state, action, reward, next_state,
                             terminal):
        '''
        Add the state, action, reward, next_state, terminal tuple to the replay buffer
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
            
        action: int, provided by the environment
            index of the action taken by the agent
            
        reward: float, provided by the environment
            reward for the given state, action, next state transition
            
        next_state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
        
        terminal: bool, usually provided by the environment
            whether or not the current episode has ended
        '''
        processed_state = self.process_state(state)
        processed_next_state = self.process_state(next_state)

        experience = (processed_state, action, reward, processed_next_state,
                      terminal)
        self.replay_buffer.add(experience)

    def action(self, state, mode='train'):
        '''
        Selects an action according to the greedy or epsilon-greedy policy
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
            
        mode: 'train' or 'test'
            selects an action according to the epsilon-greedy policy when set to 'train'
            selects an action according to the greedy policy when set to 'test'
        '''
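        # a new action is selected only every action_repeat steps (and on the very
        # first call); otherwise the previously chosen action is repeated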
        if (self.current_step % self.action_repeat
                == 0) or (not hasattr(self, 'previous_action')):
            if mode == 'test':
                state_policy, action = self.greedy_policy(state)
            else:
                state_policy, action = self.epsilon_greedy_policy(state)
        else:
            action = self.previous_action

        self.previous_action = action

        return action

    def greedy_policy(self, state):
        '''
        Returns the greedy policy as a discrete probability distribution and the 
        greedy action
        All actions except the greedy action have probability 0
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
        '''
        Q_s = self.estimate_q(state, process_state=True)

        action = np.argmax(Q_s)
        policy = np.zeros(self.n_actions)
        policy[action] = 1.0

        return policy, action

    def epsilon_greedy_policy(self, state):
        '''
        Returns the epsilon-greedy policy as a discrete probability distribution and
        an action randomly selected according to the probability distribution
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
        '''
        Q_s = self.estimate_q(state, process_state=True)

        policy = np.ones(self.n_actions) * self.epsilon / self.n_actions
        policy[np.argmax(
            Q_s)] = 1.0 - self.epsilon + self.epsilon / self.n_actions
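        # e.g. with n_actions = 3 and epsilon = 0.1 the non-greedy actions each get
        # probability 0.1 / 3 and the greedy action gets 1 - 0.1 + 0.1 / 3, so the
        # distribution still sums to 1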

        action = np.random.choice(self.actions, p=policy)

        return policy, action

    def estimate_q(self, state, process_state=True):
        '''
        Estimates the Q values for a given state and all actions from the online network
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
            
        process_state: bool
            whether to process the state before estimating Q_s
        '''
        if process_state:
            processed_state = self.process_state(state)
        else:
            processed_state = state

        with torch.no_grad():
            Q_s = self.online_network(processed_state)

        return Q_s

    def estimate_target_q(self, state, process_state=True):
        '''
        Estimates the Q values for a given state and all actions from the target network
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
            
        process_state: bool
            whether to process the state before estimating Q_s
        '''
        if process_state:
            processed_state = self.process_state(state)
        else:
            processed_state = state

        with torch.no_grad():
            Q_s = self.target_network(processed_state)

        return Q_s

    def update_target(self,
                      state,
                      action,
                      reward,
                      next_state,
                      terminal,
                      process_state=True):
        '''
        Calculates the update target for the state, action, reward, next_state, terminal tuple
        
        state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
            
        action: int, provided by the environment
            index of the action taken by the agent
            
        reward: float, provided by the environment
            reward for the given state, action, next state transition
            
        next_state: object, provided by the environment
            state provided by the environment, usually a vector or tensor
        
        terminal: bool, usually provided by the environment
            whether or not the current episode has ended
            
        process_state: bool
            whether to process the state before estimating Q_s_next
        '''
        Q_s_next = self.estimate_target_q(next_state,
                                          process_state=process_state)
        terminal_mask = torch.tensor([not t for t in terminal],
                                     dtype=torch.float32)
        update_target = reward + self.gamma * torch.max(
            Q_s_next, dim=1)[0] * terminal_mask
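        # terminal_mask is 0 for terminal transitions, so the bootstrapped term
        # gamma * max_a Q(s', a) is dropped and the target reduces to the reward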

        return update_target

    def update(self):
        '''
        Updates the model by taking a step from the optimizer
        This version does not include gradient clipping
        '''
        if self.current_step >= self.replay_start:
            if self.current_step % self.target_update_freq == 0:
                self.assign_variables()

            if self.current_step % self.update_freq == 0:
                experiences = self.replay_buffer.sample(self.batch_size)
                state, action, reward, next_state, terminal = zip(*experiences)

                state = torch.cat(state)
                action = torch.tensor(action, dtype=torch.int64)
                reward = torch.tensor(reward, dtype=torch.float32)
                next_state = torch.cat(next_state)

                update_target = self.update_target(state,
                                                   action,
                                                   reward,
                                                   next_state,
                                                   terminal,
                                                   process_state=False)
                Q_sa = self.online_network(state).gather(
                    1, action.unsqueeze(1)).squeeze()

                loss = self.loss_function(Q_sa, update_target)

                self.optimizer.zero_grad()
                loss.backward()
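                # gradient clipping is omitted in this version (see the docstring above);
                # if wanted, nn.utils.clip_grad_norm_(self.online_network.parameters(),
                # max_norm) could be inserted here with a chosen max_norm, between
                # backward() and step()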
                self.optimizer.step()

        self.current_step += 1

    def update_epsilon(self):
        '''
        Decays epsilon by the decay rate
        '''
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)
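
# A hedged usage sketch for the class above: one possible way to drive it on a small
# control task. QNet, to_tensor, and the hyperparameter overrides are illustrative
# assumptions rather than part of the original example, and the older gym API
# (reset() returns the observation, step() returns a 4-tuple) is assumed, as in the
# train() snippet below.
import gym
import torch
import torch.nn as nn

class QNet(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(params['state_size'], 64),
                                 nn.ReLU(),
                                 nn.Linear(64, params['n_actions']))

    def forward(self, x):
        return self.net(x)

def to_tensor(state):
    # add a batch dimension so torch.cat in update() can stack sampled experiences
    return torch.tensor(state, dtype=torch.float32).unsqueeze(0)

env = gym.make('CartPole-v1')
agent = DQNAgent(model=QNet,
                 model_params={'state_size': 4, 'n_actions': 2},
                 state_processor=to_tensor,
                 n_actions=2,
                 replay_memory=10000,
                 replay_start=500)

for episode in range(200):
    state = env.reset()
    done = False
    while not done:
        action = agent.action(state)
        next_state, reward, done, _ = env.step(action)
        agent.add_to_replay_memory(state, action, reward, next_state, done)
        agent.update()
        state = next_state
    agent.update_epsilon()
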
def train():
    from linear_schedule import Linear

    ledger = defaultdict(lambda: MovingAverage(Reporting.reward_average))

    M.config(file=os.path.join(RUN.log_directory, RUN.log_file))
    M.diff()

    with U.make_session(
            RUN.num_cpu), Logger(RUN.log_directory) as logger, contextify(
                gym.make(G.env_name)) as env:
        env = ScaledFloatFrame(wrap_dqn(env))

        if G.seed is not None:
            env.seed(G.seed)
        logger.log_params(G=vars(G), RUN=vars(RUN), Reporting=vars(Reporting))
        inputs = TrainInputs(action_space=env.action_space,
                             observation_space=env.observation_space)
        trainer = QTrainer(inputs=inputs,
                           action_space=env.action_space,
                           observation_space=env.observation_space)
        if G.prioritized_replay:
            replay_buffer = PrioritizedReplayBuffer(size=G.buffer_size,
                                                    alpha=G.alpha)
        else:
            replay_buffer = ReplayBuffer(size=G.buffer_size)

        class schedules:
            # note: it is important to have this start from the beginning.
            eps = Linear(G.n_timesteps * G.exploration_fraction, 1,
                         G.final_eps)
            if G.prioritized_replay:
                beta = Linear(G.n_timesteps - G.learning_start, G.beta_start,
                              G.beta_end)

        U.initialize()
        trainer.update_target()
        x = np.array(env.reset())
        ep_ind = 0
        M.tic('episode')
        for t_step in range(G.n_timesteps):
            # schedules
            eps = 0 if G.param_noise else schedules.eps[t_step]
            if G.prioritized_replay:
                beta = schedules.beta[t_step - G.learning_start]

            x0 = x
            M.tic('sample', silent=True)
            (action, *_), action_q, q = trainer.runner.act([x], eps)
            x, rew, done, info = env.step(action)
            ledger['action_q_value'].append(action_q.max())
            ledger['action_q_value/mean'].append(action_q.mean())
            ledger['action_q_value/var'].append(action_q.var())
            ledger['q_value'].append(q.max())
            ledger['q_value/mean'].append(q.mean())
            ledger['q_value/var'].append(q.var())
            ledger['timing/sample'].append(M.toc('sample', silent=True))
            # note: adding sample to the buffer is identical between the prioritized and the standard replay strategy.
            replay_buffer.add(s0=x0,
                              action=action,
                              reward=rew,
                              s1=x,
                              done=float(done))

            logger.log(
                t_step, {
                    'q_value': ledger['q_value'].latest,
                    'q_value/mean': ledger['q_value/mean'].latest,
                    'q_value/var': ledger['q_value/var'].latest,
                    'q_value/action': ledger['action_q_value'].latest,
                    'q_value/action/mean':
                    ledger['action_q_value/mean'].latest,
                    'q_value/action/var': ledger['action_q_value/var'].latest
                },
                action=action,
                eps=eps,
                silent=True)

            if G.prioritized_replay:
                logger.log(t_step, beta=beta, silent=True)

            if done:
                ledger['timing/episode'].append(M.split('episode',
                                                        silent=True))
                ep_ind += 1
                x = np.array(env.reset())
                ledger['rewards'].append(info['total_reward'])

                silent = (ep_ind % Reporting.print_interval != 0)
                logger.log(t_step,
                           timestep=t_step,
                           episode=green(ep_ind),
                           total_reward=ledger['rewards'].latest,
                           episode_length=info['timesteps'],
                           silent=silent)
                logger.log(t_step, {
                    'total_reward/mean':
                    yellow(ledger['rewards'].mean, lambda v: f"{v:.1f}"),
                    'total_reward/max':
                    yellow(ledger['rewards'].max, lambda v: f"{v:.1f}"),
                    "time_spent_exploring":
                    default(eps, percent),
                    "timing/episode":
                    green(ledger['timing/episode'].latest, sec),
                    "timing/episode/mean":
                    green(ledger['timing/episode'].mean, sec),
                },
                           silent=silent)
                try:
                    logger.log(t_step, {
                        "timing/sample":
                        default(ledger['timing/sample'].latest, sec),
                        "timing/sample/mean":
                        default(ledger['timing/sample'].mean, sec),
                        "timing/train":
                        default(ledger['timing/train'].latest, sec),
                        "timing/train/mean":
                        green(ledger['timing/train'].mean, sec),
                        "timing/log_histogram":
                        default(ledger['timing/log_histogram'].latest, sec),
                        "timing/log_histogram/mean":
                        default(ledger['timing/log_histogram'].mean, sec)
                    },
                               silent=silent)
                    if G.prioritized_replay:
                        logger.log(t_step, {
                            "timing/update_priorities":
                            default(ledger['timing/update_priorities'].latest,
                                    sec),
                            "timing/update_priorities/mean":
                            default(ledger['timing/update_priorities'].mean,
                                    sec)
                        },
                                   silent=silent)
                except Exception as e:
                    pass
                if G.prioritized_replay:
                    logger.log(
                        t_step,
                        {"replay_beta": default(beta, lambda v: f"{v:.2f}")},
                        silent=silent)

            # note: learn here.
            if t_step >= G.learning_start and t_step % G.learn_interval == 0:
                if G.prioritized_replay:
                    experiences, weights, indices = replay_buffer.sample(
                        G.replay_batch_size, beta)
                    logger.log_histogram(t_step, weights=weights)
                else:
                    experiences, weights = replay_buffer.sample(
                        G.replay_batch_size), None
                M.tic('train', silent=True)
                x0s, actions, rewards, x1s, dones = zip(*experiences)
                td_error_val, loss_val = trainer.train(s0s=x0s,
                                                       actions=actions,
                                                       rewards=rewards,
                                                       s1s=x1s,
                                                       dones=dones,
                                                       sample_weights=weights)
                ledger['timing/train'].append(M.toc('train', silent=True))
                M.tic('log_histogram', silent=True)
                logger.log_histogram(t_step, td_error=td_error_val)
                ledger['timing/log_histogram'].append(
                    M.toc('log_histogram', silent=True))
                if G.prioritized_replay:
                    M.tic('update_priorities', silent=True)
                    new_priorities = np.abs(td_error_val) + eps
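                    # note: `eps` here is the exploration epsilon computed above,
                    # reused as the offset added to |TD error| when setting priorities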
                    replay_buffer.update_priorities(indices, new_priorities)
                    ledger['timing/update_priorities'].append(
                        M.toc('update_priorities', silent=True))

            if t_step % G.target_network_update_interval == 0:
                trainer.update_target()

            if t_step % Reporting.checkpoint_interval == 0:
                U.save_state(os.path.join(RUN.log_directory, RUN.checkpoint))