Example No. 1
    def __init__(self, state_size, action_size, seed, Dueling_Normal):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
            Dueling_Normal (str): 'Normal' for a standard Q-network, 'Dueling' for a dueling Q-network
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        # Q-Network
        if Dueling_Normal == 'Normal':
            self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
            self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device)
        # Dueling Q-Network
        elif Dueling_Normal == 'Dueling':
            self.qnetwork_local = DuelQNetwork(state_size, action_size, seed).to(device)
            self.qnetwork_target = DuelQNetwork(state_size, action_size, seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
        #self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, 10, 2)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
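The QNetwork and DuelQNetwork classes referenced throughout these examples are not shown. A minimal sketch of the dueling architecture DuelQNetwork presumably implements (a shared trunk feeding separate value and advantage streams, combined as Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)) could look like this; the layer sizes are illustrative assumptions, not the original network:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DuelQNetwork(nn.Module):
    """Minimal dueling Q-network sketch: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)."""

    def __init__(self, state_size, action_size, seed, hidden=64):
        super().__init__()
        torch.manual_seed(seed)
        self.feature = nn.Linear(state_size, hidden)        # shared trunk
        self.value = nn.Linear(hidden, 1)                    # state-value stream V(s)
        self.advantage = nn.Linear(hidden, action_size)      # advantage stream A(s, a)

    def forward(self, state):
        x = F.relu(self.feature(state))
        v = self.value(x)
        a = self.advantage(x)
        return v + a - a.mean(dim=1, keepdim=True)           # aggregate the two streams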
Example No. 2
    def __init__(self,
                 state_size,
                 action_size,
                 seed,
                 lr_decay=0.9999,
                 double_dqn=False,
                 duel_dqn=False,
                 prio_exp=False):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): Dimension of each State
            action_size (int): Dimension of each Action
            seed (int): Random Seed
            lr_decay (float): decay factor for the learning rate
            double_dqn (bool): indicator for Double Deep Q-Network
            duel_dqn (bool): indicator for Dueling Deep Q-Network
            prio_exp (bool): indicator for Prioritized Experience Replay
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)
        self.lr_decay = lr_decay
        self.DOUBLE_DQN = double_dqn
        self.DUEL_DQN = duel_dqn
        self.PRIORITISED_EXPERIENCE = prio_exp

        # Determine Deep Q-Network for use
        if self.DUEL_DQN:
            self.qnetwork_local = DuelQNetwork(state_size, action_size,
                                               seed).to(device)
            self.qnetwork_target = DuelQNetwork(state_size, action_size,
                                                seed).to(device)
        else:
            self.qnetwork_local = QNetwork(state_size, action_size,
                                           seed).to(device)
            self.qnetwork_target = QNetwork(state_size, action_size,
                                            seed).to(device)

        # Initialize Optimizer
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Determine if Prioritized Experience will be used
        if self.PRIORITISED_EXPERIENCE:
            self.memory = PrioritizedReplayBuffer(action_size,
                                                  BUFFER_SIZE,
                                                  BATCH_SIZE,
                                                  seed,
                                                  alpha=0.6,
                                                  beta=0.4,
                                                  beta_anneal=1.0001)
        else:
            self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE,
                                       seed)

        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
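The PrioritizedReplayBuffer constructor above takes alpha, beta and beta_anneal. As a rough sketch of the proportional-prioritization math such a buffer is assumed to implement (sampling probabilities P(i) = p_i^alpha / sum_k p_k^alpha and importance-sampling weights w_i = (N * P(i))^(-beta), normalized by their maximum), not the buffer's actual code:

import numpy as np

def per_sample(priorities, batch_size, alpha=0.6, beta=0.4, rng=None):
    """Sketch of proportional prioritized sampling over a numpy array of priorities."""
    rng = np.random.default_rng() if rng is None else rng
    probs = priorities ** alpha
    probs = probs / probs.sum()                              # P(i) = p_i^alpha / sum_k p_k^alpha
    indices = rng.choice(len(priorities), batch_size, p=probs)
    weights = (len(priorities) * probs[indices]) ** (-beta)  # importance-sampling correction
    weights = weights / weights.max()                        # normalize so the largest weight is 1
    return indices, weights

The beta_anneal=1.0001 argument in the example suggests beta is multiplied toward 1 after each sampling step, so the importance-sampling correction becomes unbiased late in training.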
Example No. 3
    def __init__(self, state_size, action_size, seed,
                 buffer_size, batch_size, gamma, tau,
                 lr, update_every, update_target_network_every,
                 alpha, sequence_length,
                 use_dueling_network):
        """Base class for agents

        Args:
        state_size (int): dimension of each state
        action_size (int): the number of actions
        seed (int): random seed
        buffer_size (int): the maximum number of elements in the replay buffer
        batch_size (int): mini-batch size
        gamma (float): reward discount rate
        tau (float): target network soft-update rate; use 1 to copy the local network into the target network all at once
        lr (float): learning rate for the Adam optimizer
        update_every (int): how often the learning step runs (e.g. 4 means learning is executed every 4 actions taken)
        update_target_network_every (int): how often the target network is updated from the local network
        alpha (float): weight controlling how much priorities influence the sampling probability (0 means uniform random sampling)
        sequence_length (int): if greater than 1, the recurrent dueling DQN is used; use 1 for non-recurrent networks
        use_dueling_network (bool): True to use the dueling DQN
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        if use_dueling_network is True:
            if sequence_length > 1:
                self.qnetwork_local = DuelQGRUNetwork(state_size, action_size, seed).to(self.device)
                self.qnetwork_target = DuelQGRUNetwork(state_size, action_size, seed).to(self.device)
            else:
                self.qnetwork_local = DuelQNetwork(state_size, action_size, seed).to(self.device)
                self.qnetwork_target = DuelQNetwork(state_size, action_size, seed).to(self.device)
        else:
            self.qnetwork_local = QNetwork(state_size, action_size, seed).to(self.device)
            self.qnetwork_target = QNetwork(state_size, action_size, seed).to(self.device)

        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=lr)

        self.batch_size = batch_size
        self.update_every = update_every
        self.update_target_network_every = update_target_network_every
        self.gamma = gamma
        self.tau = tau
        self.sequence_length = sequence_length

        # Replay memory
        self.memory = ReplayBuffer(action_size, buffer_size, batch_size, seed, self.device, alpha, sequence_length)

        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
        self.u_step = 0  # used to control how often we copy local parameters to target parameters

        self.resetSequences()
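As a usage sketch for this constructor (the class name Agent and the environment-derived sizes are placeholders, and the hyperparameter values are only illustrative):

# Hypothetical instantiation; state_size/action_size would come from the environment.
agent = Agent(state_size=37, action_size=4, seed=0,
              buffer_size=int(1e5), batch_size=64, gamma=0.99, tau=1e-3,
              lr=5e-4, update_every=4, update_target_network_every=4,
              alpha=0.6,
              sequence_length=1,            # a value > 1 would select the GRU-based dueling network
              use_dueling_network=True)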
Example No. 4
    def __init__(self,
                 state_size,
                 action_size,
                 seed,
                 dqn_type="double",
                 dueling=True):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
            dqn_type (str): DQN variant, e.g. 'simple' or 'double'
            dueling (bool): use a dueling network architecture
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)
        self.dqn_type = dqn_type

        # Q-Network
        if dueling:
            self.qnetwork_local = DuelQNetwork(state_size, action_size,
                                               seed).to(device)
            self.qnetwork_target = DuelQNetwork(state_size, action_size,
                                                seed).to(device)
        else:
            self.qnetwork_local = QNetwork(state_size, action_size,
                                           seed).to(device)
            self.qnetwork_target = QNetwork(state_size, action_size,
                                            seed).to(device)

        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
        # self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, 10, 2)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
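For illustration, assuming an environment with hypothetical sizes, the dqn_type and dueling flags combine independently:

# Plain DQN with a standard network vs. double DQN with a dueling network (illustrative sizes).
vanilla_agent = Agent(state_size=37, action_size=4, seed=0, dqn_type="simple", dueling=False)
dueling_ddqn = Agent(state_size=37, action_size=4, seed=0, dqn_type="double", dueling=True)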
Example No. 5
    def __init__(self, state_size, action_size, seed):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
        """

        self.state_size = state_size[0]
        self.action_size = action_size
        self.seed = random.seed(seed)

        # DuelQNetwork
        self.qnetwork_local = DuelQNetwork(self.state_size, self.action_size, seed).to(device)
        self.qnetwork_target = DuelQNetwork(self.state_size, self.action_size, seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Prioritized Experienced Replay memory
        self.memory = PrioritizedReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
Example No. 6
    def reset(self):
        np.random.seed(self.seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
        self.t_last_update = 0
        self.tau = 1.0
        self.prev_prev_act = -1
        self.prev_act = -1
        self.lr = self.lr_max
        self.lr_at_minimum = False
        if self.verbose_level > 0:
            print("Reset DDQN agent with parameters:")
            print("state_size=\t{}".format(self.state_size))
            print("action_size=\t{}".format(self.action_size))
            print("hidden_layers=\t", self.hidden_layers)
            print("seed=\t{}".format(self.seed))
            print("update_every=\t{}".format(self.update_every))
            print("batch_size=\t{}".format(self.batch_size))
            print("buffer_size=\t{}".format(self.buffer_size))
            print("learning_rate=\t{:.4e}".format(self.lr))
            print("tau=\t{:.3f}".format(self.tau0))
            print("random_walk=\t", self.random_walk)
            print("gamma=\t{:.3f}".format(self.gamma))

        # Q-Network
        self.qnetwork_local = DuelQNetwork(self.state_size, self.action_size,
                                           self.hidden_layers,
                                           self.seed).to(self.device)
        self.qnetwork_target = DuelQNetwork(self.state_size, self.action_size,
                                            self.hidden_layers,
                                            self.seed).to(self.device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(),
                                    lr=self.lr)
        # Replay memory
        self.memory = ReplayBuffer(self.state_size, self.action_size,
                                   self.buffer_size, self.batch_size,
                                   self.seed)
Example No. 7
class Agent():
    """Interacts with and learns from the environment."""

    def __init__(self, state_size, action_size, seed, Dueling_Normal):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
            Dueling_Normal (str): 'Normal' for a standard Q-network, 'Dueling' for a dueling Q-network
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        # Q-Network
        if Dueling_Normal == 'Normal':
            self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
            self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device)
        # Dueling Q-Network
        elif Dueling_Normal == 'Dueling':
            self.qnetwork_local = DuelQNetwork(state_size, action_size, seed).to(device)
            self.qnetwork_target = DuelQNetwork(state_size, action_size, seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
        #self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, 10, 2)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
    
    def step(self, state, action, reward, next_state, done, type_dqn):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)
        
        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > BATCH_SIZE:
                experiences = self.memory.sample()
                self.learn(experiences, GAMMA, type_dqn)

    def act(self, state, eps=0.):
        """Returns actions for given state as per current policy.
        
        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences, gamma, type_dqn):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tensors
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences
        if type_dqn == 'simple':
            ########## DQN
            # Get max predicted Q values (for next states) from target model
            Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)

        elif type_dqn == 'double':
        ########## Double DQN
            max_local_action = torch.argmax(self.qnetwork_local(next_states), 1)  # tensor of length batch_size
            Q_targets_next = self.qnetwork_target(next_states)
            Q_targets_next = Q_targets_next[torch.arange(Q_targets_next.shape[0]), max_local_action].unsqueeze(1)  # shape (batch_size, 1)
 
        # Compute Q targets for current states 
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        Q_targets = Q_targets.detach()  # we don't backpropagate through the target values

        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)
        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        #self.scheduler()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)                     

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
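A minimal loop driving this agent's act/step interface could look like the following sketch, where env is assumed to follow the classic gym API and the epsilon schedule is illustrative; it is not part of the original example:

def train(agent, env, n_episodes=500, eps_start=1.0, eps_min=0.01, eps_decay=0.995):
    """Sketch of a training loop for the Agent above."""
    eps = eps_start
    for _ in range(n_episodes):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state, eps)                                  # epsilon-greedy action
            next_state, reward, done, _ = env.step(action)
            agent.step(state, action, reward, next_state, done, 'double')   # store experience + learn
            state = next_state
        eps = max(eps_min, eps * eps_decay)                                 # decay exploration over episodes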
Example No. 8
class ddqn_agent():
    """Interacts with and learns from the environment."""
    def __init__(self, state_size, action_size,
                 hidden_layers = [[64,64,32],[],[]],
                 update_every  = 4,
                 batch_size    = 128,
                 buffer_size   = int(1e5),
                 learning_rate = 5e-4,
                 tau           = 1e-3,
                 gamma         = 0.99,
                 random_walk   = [0.75, 0.05, 0.1, 0.1],
                 device        = None,
                 verbose_level = 2):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            hidden_layers (list of lists of int): hidden layer structure
            update_every (int): how often to update the network
            batch_size (int): minibatch size for sampling the replay buffer and training the network
            buffer_size (int): replay buffer capacity
            learning_rate (float): learning rate of the local network
            tau (float): for soft update of target parameters
            gamma (float): discount factor applied to the next-state Q value
            random_walk (array-like of action_size floats): a priori probabilities for the random walk
            device: cpu or gpu
            verbose_level (int): 0 = silent, 1 = print parameters on reset, 2 = also print learning-rate changes
        """
        if device is None:
            self.device = torch.device("cpu")
        else:
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
        self.state_size = state_size
        self.action_size = action_size
        self.action_range = np.arange(self.action_size)
        self.seed = 0
        self.update_every = update_every
        self.batch_size = batch_size
        assert type(learning_rate) is float
        self.lr_max = learning_rate
        self.lr_min = 1e-5
        self.lr_decay = 0.5
        self.hidden_layers = hidden_layers
        self.tau0 = tau
        self.buffer_size = buffer_size
        self.gamma = gamma
        self.verbose_level = verbose_level
        self.set_random_walk_probabilities(random_walk)
        self.reset()

    def reset(self):
        np.random.seed(self.seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
        self.t_last_update = 0
        self.tau = 1.0
        self.prev_prev_act = -1
        self.prev_act = -1
        self.lr = self.lr_max
        self.lr_at_minimum = False
        if self.verbose_level > 0:
            print("Reset DDQN agent with parameters:")
            print("state_size=\t{}".format(self.state_size))
            print("action_size=\t{}".format(self.action_size))
            print("hidden_layers=\t", self.hidden_layers)
            print("seed=\t{}".format(self.seed))
            print("update_every=\t{}".format(self.update_every))
            print("batch_size=\t{}".format(self.batch_size))
            print("buffer_size=\t{}".format(self.buffer_size))
            print("learning_rate=\t{:.4e}".format(self.lr))
            print("tau=\t{:.3f}".format(self.tau0))
            print("random_walk=\t", self.random_walk)
            print("gamma=\t{:.3f}".format(self.gamma))

        # Q-Network
        self.qnetwork_local = DuelQNetwork(self.state_size, self.action_size,
                                           self.hidden_layers,
                                           self.seed).to(self.device)
        self.qnetwork_target = DuelQNetwork(self.state_size, self.action_size,
                                            self.hidden_layers,
                                            self.seed).to(self.device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(),
                                    lr=self.lr)
        # Replay memory
        self.memory = ReplayBuffer(self.state_size, self.action_size,
                                   self.buffer_size, self.batch_size,
                                   self.seed)

    def set_random_walk_probabilities(self,
                                      random_walk=[0.75, 0.05, 0.1, 0.1]):
        if isinstance(random_walk, (list, np.ndarray)):
            if len(random_walk) != self.action_size:
                raise ValueError(
                    "random_walk must contain one probability per action")
            self.random_walk = np.asarray(random_walk,
                                          dtype=float) / np.sum(random_walk)
            return
        # otherwise random_walk is a single number between 0.0 and 1.0:
        # 0.0 --> equal probabilities, 1.0 --> only forward walk
        # reshape the distribution to favor the forward action over the other actions
        equal_prob = 1.0 / self.action_size
        forward_prob = random_walk * 1.0 + (1.0 - random_walk) * equal_prob
        other_prob = (1.0 - forward_prob) / (self.action_size - 1)
        self.random_walk = np.full(self.action_size, fill_value=other_prob)
        self.random_walk[0] = forward_prob
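    # Worked example of the scalar branch above (added for illustration):
    # with action_size = 4 and random_walk = 0.75,
    #   equal_prob   = 1 / 4              = 0.25
    #   forward_prob = 0.75 + 0.25 * 0.25 = 0.8125
    #   other_prob   = (1 - 0.8125) / 3   = 0.0625
    # so self.random_walk becomes [0.8125, 0.0625, 0.0625, 0.0625].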

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        next_value_multiplier = self.gamma * (1 - done)
        self.memory.add(state, action, reward, next_state,
                        next_value_multiplier)

        # Learn every UPDATE_EVERY time steps.
        self.t_step += 1
        if self.t_step < self.t_last_update + self.update_every:
            return
        self.t_last_update = self.t_step
        experiences = self.memory.sample()
        if experiences is None:
            # not enough samples in memory yet to form a batch
            return
        states, actions, rewards, next_states, next_value_multipliers = experiences
        states = torch.from_numpy(states).float().to(self.device)
        actions = torch.from_numpy(actions).long().to(self.device)
        rewards = torch.from_numpy(rewards).float().to(self.device)
        next_states = torch.from_numpy(next_states).float().to(self.device)
        next_value_multipliers = torch.from_numpy(
            next_value_multipliers).float().to(self.device)

        self.qnetwork_target.train()  # batch norm can update itself
        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.qnetwork_target(next_states).detach().max(
            1)[0].unsqueeze(1)
        # Compute Q targets for current states
        Q_targets = rewards + next_value_multipliers * Q_targets_next

        # Get expected Q values from local model
        self.qnetwork_local.train()
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- soft update target network ------------------- #
        for target_param, local_param in zip(self.qnetwork_target.parameters(),
                                             self.qnetwork_local.parameters()):
            target_param.data.copy_(self.tau * local_param.data +
                                    (1.0 - self.tau) * target_param.data)
        # anneal tau from its initial value of 1.0 down toward tau0
        self.tau = min(max(self.tau / (1 + self.tau - self.tau0), self.tau0),
                       self.tau)

    def act(self, state, eps=0.):
        """Returns actions for given state as per current policy.
        
        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        # Epsilon-greedy action selection
        exploitation = np.random.random() >= eps
        if exploitation:
            # exploit the agent knowledge
            state = torch.from_numpy(state).float().unsqueeze(0).to(
                self.device)
            self.qnetwork_local.eval()
            with torch.no_grad():
                action_values = self.qnetwork_local(state).cpu().data.numpy()
            new_act = np.argmax(action_values)
        else:
            # explore the environment using random move
            # try to employ a "common sense" in order to avoid redundant random moves
            redundant = -1
            if self.prev_act == 0 and self.prev_prev_act == 1:
                # prev step was forward, before that was backward --> inhibit backward
                redundant = 1
            elif self.prev_act == 1 and self.prev_prev_act == 0:
                # prev step was backward, before that was forward --> inhibit forward
                redundant = 0
            elif self.prev_act == 2 and self.prev_prev_act == 3:
                # prev step was left, before that was right --> inhibit right
                redundant = 3
            elif self.prev_act == 3 and self.prev_prev_act == 2:
                # prev step was right, before that was left --> inhibit left
                redundant = 2
            random_walk = self.random_walk.copy()  # copy so the stored distribution is not modified in place
            if redundant >= 0:
                # move the redundant action's probability onto the remaining actions
                redundant_prob = random_walk[redundant]
                random_walk += redundant_prob / (self.action_size - 1)
                random_walk[redundant] = 0
            new_act = np.random.choice(self.action_range, p=random_walk)
        self.prev_prev_act = self.prev_act
        self.prev_act = new_act
        return new_act, exploitation

    def learning_rate_step(self):
        self.lr *= self.lr_decay
        if self.lr <= self.lr_min:
            self.lr_at_minimum = True
            self.lr = self.lr_min
        if self.verbose_level > 1:
            print("\nChanging learning rate to: {:.4e}".format(self.lr))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def is_lr_at_minimum(self):
        if self.lr_at_minimum and self.verbose_level > 1:
            print(
                "\nCannot reduce learning rate because it is already at the minimum: {:.4e}"
                .format(self.lr))
        return self.lr_at_minimum

    def save(self, filename):
        shutil.rmtree(filename,
                      ignore_errors=True)  # avoid file not found error
        os.makedirs(filename)
        torch.save(self.qnetwork_local.state_dict(),
                   os.path.join(filename, "local.pth"))
        torch.save(self.qnetwork_target.state_dict(),
                   os.path.join(filename, "target.pth"))
        torch.save(self.optimizer.state_dict(),
                   os.path.join(filename, "optimizer.pth"))

    def load(self, filename):
        self.qnetwork_local.load_state_dict(
            torch.load(os.path.join(filename, "local.pth")))
        self.qnetwork_target.load_state_dict(
            torch.load(os.path.join(filename, "target.pth")))
        self.optimizer.load_state_dict(
            torch.load(os.path.join(filename, "optimizer.pth")))
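For completeness, a short checkpointing sketch using the save/load methods above (the directory name and environment sizes are arbitrary):

# Hypothetical checkpoint round trip with the ddqn_agent class above.
agent = ddqn_agent(state_size=37, action_size=4)
agent.save("checkpoints/ddqn")   # writes local.pth, target.pth and optimizer.pth into the directory
agent.load("checkpoints/ddqn")   # restores both networks and the optimizer state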
Example No. 9
class Agent():
    """Interacts with and learns from the environment."""

    def __init__(self, state_size, action_size, seed):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
        """

        self.state_size = state_size[0]
        self.action_size = action_size
        self.seed = random.seed(seed)

        # DuelQNetwork
        self.qnetwork_local = DuelQNetwork(self.state_size, self.action_size, seed).to(device)
        self.qnetwork_target = DuelQNetwork(self.state_size, self.action_size, seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Prioritized Experienced Replay memory
        self.memory = PrioritizedReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0

    def step(self, state, action, reward, next_state, done, beta=1.):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every update_every time steps.
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > BATCH_SIZE:
                experiences = self.memory.sample(ALPHA, beta)
                # The commented-out line below appears to be a no-op, so it is left disabled.
                # action_values = self.qnetwork_local(experiences[0])
                self.learn(experiences, GAMMA)

    def act(self, state, eps=0.):
        """Returns actions for given state as per current policy.

        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if action_values is not None and random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences, gamma):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done, weights, indices) tensors
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones, weights, indices = experiences

        # Get max action from local model
        local_max_actions = self.qnetwork_local(next_states).detach().max(1)[1].unsqueeze(1)

        # Get max predicted Q values (for next states) from target model
        Q_targets_next = torch.gather(self.qnetwork_target(next_states).detach(), 1, local_max_actions)

        # Compute Q targets for current states
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))

        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute the importance-sampling-weighted squared TD error
        # loss = F.mse_loss(Q_expected, Q_targets)
        loss = (Q_expected - Q_targets).pow(2) * weights
        adjusted_loss = loss + 1e-5  # per-sample error plus a small constant, reused below as the new priorities
        loss = loss.mean()

        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update priorities based on the (squared) TD error
        self.memory.update_priorities(indices.squeeze().cpu().data.numpy(),
                                      adjusted_loss.squeeze().detach().cpu().numpy())

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
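Since step() takes a beta argument for the importance-sampling correction, the caller is expected to anneal beta toward 1 over training. A sketch with an illustrative linear schedule (the variable names are placeholders):

# Hypothetical annealing of the importance-sampling exponent toward 1.
beta_start, n_episodes = 0.4, 1000
for episode in range(n_episodes):
    beta = min(1.0, beta_start + (1.0 - beta_start) * episode / n_episodes)
    # ... inside the episode loop:
    # agent.step(state, action, reward, next_state, done, beta=beta)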