Example #1
class Learner():
    def __init__(self, sess, s_size, a_size, scope, queues, trainer):
        self.queue = queues[0]
        self.param_queue = queues[1]
        self.replaymemory = ReplayMemory(100000)
        self.sess = sess
        self.learner_net = network(s_size, a_size, scope, 20)

        self.q = self.learner_net.q
        self.Q = self.learner_net.Q

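        # N is assumed to be a module-level constant giving the number of
        # quantiles per action (this snippet appears to target a QR-DQN-style head).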
        self.actions_q = tf.placeholder(shape=[None, a_size, N],
                                        dtype=tf.float32)
        self.q_target = tf.placeholder(shape=[None, N], dtype=tf.float32)
        self.ISWeights = tf.placeholder(shape=[None, N], dtype=tf.float32)

        self.q_actiona = tf.multiply(self.q, self.actions_q)
        self.q_action = tf.reduce_sum(self.q_actiona, axis=1)
        self.u = tf.abs(self.q_target - self.q_action)
        self.loss = tf.reduce_mean(
            tf.reduce_sum(tf.square(self.u) * self.ISWeights, axis=1))

        self.local_vars = self.learner_net.local_vars  #tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        self.gradients = tf.gradients(self.loss, self.local_vars)
        #grads,self.grad_norms = tf.clip_by_norm(self.gradients,40.0)
        self.apply_grads = trainer.apply_gradients(
            zip(self.gradients, self.local_vars))
        self.sess.run(tf.global_variables_initializer())

    def run(self, gamma, s_size, a_size, batch_size, env):
        print('start learning')
        step, train1 = 0, False
        epi_q = []
        self.env = env
        while True:
            # Drain any transitions pushed onto the queue by the actor processes
            while not self.queue.empty():
                t_error = self.queue.get()
                step += 1
                self.replaymemory.add(t_error)

            if self.param_queue.empty():
                params = self.sess.run(self.local_vars)
                self.param_queue.put(params)

            if step >= 10000:
                train1 = True
                step = 0

            if train1:
                episode_buffer, tree_idx, ISWeights = self.replaymemory.sample(
                    batch_size)
                episode_buffer = np.array(episode_buffer)
                #print episode_buffer
                observations = episode_buffer[:, 0]
                actions = episode_buffer[:, 1]
                rewards = episode_buffer[:, 2]
                observations_next = episode_buffer[:, 3]
                dones = episode_buffer[:, 4]
                Q_target = self.sess.run(self.Q,
                                         feed_dict={
                                             self.learner_net.inputs:
                                             np.vstack(observations_next)
                                         })

                actions_ = np.argmax(Q_target, axis=1)

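                # One-hot masks for the stored actions and the greedy next actions,
                # tiled across the N quantile dimension so they can select
                # per-quantile values in the ops below.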
                action = np.zeros((batch_size, a_size))
                action_ = np.zeros((batch_size, a_size))
                for i in range(batch_size):
                    action[i][actions[i]] = 1
                    action_[i][actions_[i]] = 1
                action_now = np.zeros((batch_size, a_size, N))
                action_next = np.zeros((batch_size, a_size, N))
                for i in range(batch_size):
                    for j in range(a_size):
                        for k in range(N):
                            action_now[i][j][k] = action[i][j]
                            action_next[i][j][k] = action_[i][j]
                q_target = self.sess.run(self.q_action,
                                         feed_dict={
                                             self.learner_net.inputs:
                                             np.vstack(observations_next),
                                             self.actions_q:
                                             action_next
                                         })

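                # Per-quantile Bellman target (assumed QR-DQN style):
                #   z_j = r + gamma * (1 - done) * q_j(s', argmax_a Q(s', a))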
                q_target_batch = []
                for i in range(len(q_target)):
                    qi = q_target[i]
                    z_target_step = []
                    for j in range(len(qi)):
                        z_target_step.append(gamma * qi[j] * (1 - dones[i]) +
                                             rewards[i])
                    q_target_batch.append(z_target_step)
                q_target_batch = np.array(q_target_batch)

                isweight = np.zeros((batch_size, N))
                for i in range(batch_size):
                    for j in range(N):
                        isweight[i, j] = ISWeights[i]
                feed_dict = {
                    self.q_target: q_target_batch,
                    self.learner_net.inputs: np.vstack(observations),
                    self.actions_q: action_now,
                    self.ISWeights: isweight
                }

                l, abs_errors, _ = self.sess.run(
                    [self.loss, self.u, self.apply_grads], feed_dict=feed_dict)
                #print abs_errors
                abs_errors = np.mean(abs_errors, axis=1) + 1e-6

                self.replaymemory.update_priorities(tree_idx, abs_errors)
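
A minimal sketch of how this Learner could be wired to actor processes through the two queues it expects (a transition queue and a parameter queue). The optimizer choice, the function name, and the actor side are assumptions for illustration, not part of the original example.

# Hypothetical wiring for the Learner above (illustration only).
import multiprocessing as mp
import tensorflow as tf

def launch_learner(s_size, a_size, batch_size, gamma, env):
    transition_queue = mp.Queue()  # actors .put() transitions here
    param_queue = mp.Queue()       # learner publishes parameters, actors .get() them
    trainer = tf.train.AdamOptimizer(learning_rate=1e-4)
    sess = tf.Session()
    learner = Learner(sess, s_size, a_size, 'learner',
                      (transition_queue, param_queue), trainer)
    learner.run(gamma, s_size, a_size, batch_size, env)  # blocks and trains forever
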
Example #2
class Agent():
    def __init__(self, game, agent_type, display, load_model, record, test):
        self.name = game
        self.agent_type = agent_type
        self.ale = ALEInterface()
        self.ale.setInt(str.encode('random_seed'), np.random.randint(100))
        self.ale.setBool(str.encode('display_screen'), display or record)
        if record:
            self.ale.setString(str.encode('record_screen_dir'), str.encode('./data/recordings/{}/{}/tmp/'.format(game, agent_type)))

        self.ale.loadROM(str.encode('./roms/{}.bin'.format(self.name)))
        self.action_list = list(self.ale.getMinimalActionSet())
        self.frame_shape = np.squeeze(self.ale.getScreenGrayscale()).shape
        if test:
            self.name += '_test'

        if 'space_invaders' in self.name:
            # Account for blinking bullets
            self.frameskip = 2
        else:
            self.frameskip = 3

        self.frame_buffer = deque(maxlen=4)
        if load_model and not record:
            self.load_replaymemory()
        else:
            self.replay_memory = ReplayMemory(500000, 32)

        model_input_shape = self.frame_shape + (4,)
        model_output_shape = len(self.action_list)

        if agent_type == 'dqn':
            self.model = DeepQN(
                model_input_shape,
                model_output_shape,
                self.action_list,
                self.replay_memory,
                self.name,
                load_model
            )
        elif agent_type == 'double':
            self.model = DoubleDQN(
                model_input_shape,
                model_output_shape,
                self.action_list,
                self.replay_memory,
                self.name,
                load_model
            )

        else:
            self.model = DuelingDQN(
                model_input_shape,
                model_output_shape,
                self.action_list,
                self.replay_memory,
                self.name,
                load_model
            )

        print('{} Loaded!'.format(' '.join(self.name.split('_')).title()))
        print('Displaying: ', display)
        print('Frame Shape: ', self.frame_shape)
        print('Frame Skip: ', self.frameskip)
        print('Action Set: ', self.action_list)
        print('Model Input Shape: ', model_input_shape)
        print('Model Output Shape: ', model_output_shape)
        print('Agent: ', agent_type)

    def training(self, steps):
        '''
        Trains the agent for `steps` weight updates.

        Returns the average model loss.
        '''

        loss = []

        # Initialize the frame buffer with four copies of the current frame.
        # np.squeeze drops singleton dimensions, e.g. a (210, 160, 1) screen becomes (210, 160).
        for _ in range(self.frame_buffer.maxlen):
            self.frame_buffer.append(np.squeeze(self.ale.getScreenGrayscale()))

        try:
            for step in range(steps):
                gameover = False
                initial_state = np.stack(self.frame_buffer, axis=-1)
                action = self.model.predict_action(initial_state)

                # Backup data
                if step % 5000 == 0:
                    self.model.save_model()
                    self.model.save_hyperparams()
                    self.save_replaymemory()

                # If using a target model check for weight updates
                if hasattr(self.model, 'tau'):
                    if self.model.tau == 0:
                        self.model.update_target_model()
                        self.model.tau = 10000
                    else:
                        self.model.tau -= 1

                # Frame skipping technique https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/
                lives_before = self.ale.lives()
                for _ in range(self.frameskip):
                    self.ale.act(action)

                reward = self.ale.act(action)
                self.frame_buffer.append(np.squeeze(self.ale.getScreenGrayscale()))
                lives_after = self.ale.lives()

                if lives_after < lives_before:
                    gameover = True  # Treat a lost life as a terminal transition
                    reward = -1

                if self.ale.game_over():
                    gameover = True
                    reward = -1
                    self.ale.reset_game()

                new_state = np.stack(self.frame_buffer, axis=-1)

                # Experiment with clipping rewards for stability purposes
                reward = np.clip(reward, -1, 1)
                self.replay_memory.add(
                    initial_state,
                    action,
                    reward,
                    gameover,
                    new_state
                )

                loss += self.model.replay_train()
        except BaseException:  # also catches KeyboardInterrupt so progress is saved first
            self.model.save_model()
            self.model.save_hyperparams()
            self.save_replaymemory()
            raise KeyboardInterrupt

        return np.mean(loss, axis=0)

    def simulate_random(self):
        print('Simulating game randomly')
        done = False
        total_reward = 0
        while not done:
            action = np.random.choice(self.ale.getMinimalActionSet())
            reward = self.ale.act(action)
            total_reward += reward
            if self.ale.game_over():
                reward = -1
                done = True

            reward = np.clip(reward, -1, 1)
            if reward != 0:
                print(reward)

        frames_survived = self.ale.getEpisodeFrameNumber()
        self.ale.reset_game()
        return total_reward, frames_survived

    def simulate_intelligent(self, evaluating=False):
        done = False
        total_score = 0

        # Fill the frame buffer with four copies of the current frame
        for _ in range(self.frame_buffer.maxlen):
            self.frame_buffer.append(np.squeeze(self.ale.getScreenGrayscale()))
        while not done:
            state = np.stack(self.frame_buffer, axis=-1)
            action = self.model.predict_action(state, evaluating)

            for _ in range(self.frameskip):
                self.ale.act(action)

            # Remember, ale.act returns the increase in game score with this action
            total_score += self.ale.act(action)

            # Pushes oldest frame out
            self.frame_buffer.append(np.squeeze(self.ale.getScreenGrayscale()))
            if self.ale.game_over():
                done = True

        frames_survived = self.ale.getEpisodeFrameNumber()
        print('   Game Over')
        print('   Frames Survived: ', frames_survived)
        print('   Score: ', total_score)
        print('===========================')
        self.ale.reset_game()
        return total_score, frames_survived

    def save_replaymemory(self):
        with bz2.BZ2File('./data/{}/{}_replaymem.obj'.format(self.agent_type, self.name), 'wb') as f:
            pickle.dump(self.replay_memory, f, protocol=pickle.HIGHEST_PROTOCOL)
            print('Saved replay memory at ', datetime.now())

    def load_replaymemory(self):
        try:
            with bz2.BZ2File('./data/{}/{}_replaymem.obj'.format(self.agent_type, self.name), 'rb') as f:
                self.replay_memory = pickle.load(f)
                print('Loaded replay memory at ', datetime.now())
        except FileNotFoundError:
            print('No replay memory file found')
            raise KeyboardInterrupt
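
A hedged usage sketch for the ALE Agent above (not part of the original example). It assumes the ROM file, the ./data/ directories, and the DeepQN/DoubleDQN/DuelingDQN models referenced in __init__ are available; the game name and step counts are placeholders.

# Hypothetical driver for the Agent above: alternate training blocks with
# periodic greedy evaluation runs.
if __name__ == '__main__':
    agent = Agent(game='breakout', agent_type='double', display=False,
                  load_model=False, record=False, test=False)
    for epoch in range(100):
        avg_loss = agent.training(steps=10000)
        score, frames = agent.simulate_intelligent(evaluating=True)
        print('Epoch', epoch, 'loss', avg_loss, 'score', score, 'frames', frames)
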
Example #3
class Worker():
    def __init__(self,env,name,s_size,a_size,trainer,model_path,global_episodes):
        self.name = "worker_" + str(name)
        self.number = name
        self.model_path = model_path
        self.trainer = trainer
        self.global_episodes = global_episodes
        self.increment = self.global_episodes.assign_add(1)
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_mean_values = []
        # Create the local copy of the network and the TensorFlow op to copy global parameters to the local network
        self.local_Q = Q_Network(s_size, a_size, self.name, trainer)
        self.update_local_ops = update_target_graph('global', self.name)
        self.env = env
        self.replaymemory = ReplayMemory(max_memory)
        
    def train(self,rollout,sess,gamma,ISWeights):
        rollout = np.array(rollout)
        observations      = rollout[:,0]
        actions           = rollout[:,1]
        rewards           = rollout[:,2]
        next_observations = rollout[:,3]
        dones             = rollout[:,4]
        
        Q_target = sess.run(self.local_Q.Q, feed_dict={self.local_Q.inputs:np.vstack(next_observations)})
        actions_ = np.argmax(Q_target, axis=1)
        
        action = np.zeros((batch_size, a_size))
        action_ = np.zeros((batch_size, a_size))
        for i in range(batch_size):
            action[i][actions[i]] = 1
            action_[i][actions_[i]] = 1
        action_now = np.zeros((batch_size, a_size, N))
        action_next = np.zeros((batch_size, a_size, N))
        for i in range(batch_size):
            for j in range(a_size):
                for k in range(N):
                    action_now[i][j][k] = action[i][j]
                    action_next[i][j][k] = action_[i][j]

        q_target = sess.run(self.local_Q.q_action, feed_dict={self.local_Q.inputs:np.vstack(next_observations),
                                                               self.local_Q.actions_q:action_next})
        q_target_batch = []
        for i in range(len(q_target)):
            qi = q_target[i]# * (1 - dones[i])
            z_target_step = []
            for j in range(len(qi)):
                z_target_step.append(gamma * qi[j] + rewards[i])
            q_target_batch.append(z_target_step)
        q_target_batch = np.array(q_target_batch)
        #print q_target_batch
        isweight = np.zeros((batch_size,N))
        for i in range(batch_size):
            for j in range(N):
                isweight[i,j] = ISWeights[i]
        feed_dict = {self.local_Q.inputs:np.vstack(observations),
                     self.local_Q.actions_q:action_now,
                     self.local_Q.q_target:q_target_batch,
                     self.local_Q.ISWeights:isweight}
        u,l,g_n,v_n,_ = sess.run([self.local_Q.u,
                                  self.local_Q.loss,
                                  self.local_Q.grad_norms,
                                  self.local_Q.var_norms,
                                  self.local_Q.apply_grads],feed_dict=feed_dict)
        return l/len(rollout), g_n, v_n, Q_target, u

    def work(self,gamma,sess,coord,saver):
        global GLOBAL_STEP
        episode_count = sess.run(self.global_episodes)
        total_steps = 0
        epsilon = 0.2
        
        print("Starting worker " + str(self.number))
        best_mean_episode_reward = -float('inf')
        with sess.as_default(), sess.graph.as_default():
            while not coord.should_stop():
                sess.run(self.update_local_ops)
                #episode_buffer = []
                episode_reward = 0
                episode_step_count = 0
                d = False
                s = self.env.reset()
                s = process_frame(s)
                if epsilon > 0.01:
                    epsilon = epsilon * 0.997
                while not d:
                    #self.env.render()
                    GLOBAL_STEP += 1
                    #Take an action using probabilities from policy network output.
                    if random.random() > epsilon:
                        a_dist_list = sess.run(self.local_Q.Q, feed_dict={self.local_Q.inputs:[s]})
                        a_dist = a_dist_list[0]
                        a = np.argmax(a_dist)
                    else:
                        a = random.randint(0, 5)
                    
                    s1, r, d, _ = self.env.step(a)
                    if not d:
                        s1 = process_frame(s1)
                    else:
                        s1 = s
                    self.replaymemory.add([s,a,r,s1,d])
                    episode_reward += r
                    s = s1                    
                    total_steps += 1
                    episode_step_count += 1
                    if total_steps % 2 == 0 and not d and total_steps > 50000:
                        episode_buffer, tree_idx, ISWeights = self.replaymemory.sample(batch_size)
                        l,g_n,v_n,Q_target,u = self.train(episode_buffer,sess,gamma,ISWeights)
                        u = np.mean(u,axis=1) + 1e-6
                        self.replaymemory.update_priorities(tree_idx,u)
                        #sess.run(self.update_local_ops)
                    if d:
                        break
                sess.run(self.update_local_ops)
                self.episode_rewards.append(episode_reward)
                self.episode_lengths.append(episode_step_count)

                # Periodically save gifs of episodes, model parameters, and summary statistics.
                if episode_count % 5 == 0 and episode_count != 0 and total_steps > max_memory:
                    if self.name == 'worker_0' and episode_count % 5 == 0:
                        print('\n episode: ', episode_count, 'global_step:', \
                              GLOBAL_STEP, 'mean episode reward: ', np.mean(self.episode_rewards[-10:]), \
                              'epsilon: ', epsilon)
                    
                    print('loss', l, 'Qtargetmean', np.mean(Q_target))
                    #print 'p_target', p_target
                    if episode_count % 100 == 0 and self.name == 'worker_0' and total_steps > 10000:
                        saver.save(sess,self.model_path+'/qr-dqn-'+str(episode_count)+'.cptk')
                        print("Saved Model")

                    mean_reward = np.mean(self.episode_rewards[-100:])
                    if episode_count > 20 and best_mean_episode_reward < mean_reward:
                        best_mean_episode_reward = mean_reward

                if self.name == 'worker_0':
                    sess.run(self.increment)
                    #if episode_count%1==0:
                        #print('\r {} {}'.format(episode_count, episode_reward),end=' ')
                episode_count += 1
Example #4
class Worker():
    def __init__(self, env, name, s_size, a_size, trainer, model_path,
                 global_episodes):
        self.name = "worker_" + str(name)
        self.number = name
        self.model_path = model_path
        self.trainer = trainer
        self.global_episodes = global_episodes
        self.increment = self.global_episodes.assign_add(1)
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_mean_values = []
        # Create the local copy of the network and the TensorFlow op to copy global parameters to the local network
        self.local_Q = Q_Network(s_size, a_size, self.name, trainer)
        self.update_local_ops = update_target_graph('global', self.name)
        self.env = env
        self.replaymemory = ReplayMemory(max_memory)

    def train(self, rollout, sess, gamma, ISWeights):
        rollout = np.array(rollout)
        observations = rollout[:, 0]
        actions = rollout[:, 1]
        rewards = rollout[:, 2]
        next_observations = rollout[:, 3]
        dones = rollout[:, 4]

        Q_target = sess.run(
            self.local_Q.Q,
            feed_dict={self.local_Q.inputs: np.vstack(next_observations)})
        actions_ = np.argmax(Q_target, axis=1)

        action = np.zeros((batch_size, a_size))
        action_ = np.zeros((batch_size, a_size))
        for i in range(batch_size):
            action[i][actions[i]] = 1
            action_[i][actions_[i]] = 1
        action_now = np.zeros((batch_size, a_size, N))
        action_next = np.zeros((batch_size, a_size, N))
        for i in range(batch_size):
            for j in range(a_size):
                for k in range(N):
                    action_now[i][j][k] = action[i][j]
                    action_next[i][j][k] = action_[i][j]

        q_target = sess.run(self.local_Q.q_action,
                            feed_dict={
                                self.local_Q.inputs:
                                np.vstack(next_observations),
                                self.local_Q.actions_q: action_next
                            })
        q_target_batch = []
        for i in range(len(q_target)):
            qi = q_target[i]  # * (1 - dones[i])
            z_target_step = []
            for j in range(len(qi)):
                z_target_step.append(gamma * qi[j] + rewards[i])
            q_target_batch.append(z_target_step)
        q_target_batch = np.array(q_target_batch)
        #print q_target_batch
        isweight = np.zeros((batch_size, N))
        for i in range(batch_size):
            for j in range(N):
                isweight[i, j] = ISWeights[i]
        feed_dict = {
            self.local_Q.inputs: np.vstack(observations),
            self.local_Q.actions_q: action_now,
            self.local_Q.q_target: q_target_batch,
            self.local_Q.ISWeights: isweight
        }
        u, l, g_n, v_n, _ = sess.run([
            self.local_Q.u, self.local_Q.loss, self.local_Q.grad_norms,
            self.local_Q.var_norms, self.local_Q.apply_grads
        ],
                                     feed_dict=feed_dict)
        return l / len(rollout), g_n, v_n, Q_target, u

    def work(self, gamma, sess, coord, saver):
        global GLOBAL_STEP
        episode_count = sess.run(self.global_episodes)
        total_steps = 0
        epsilon = 0.2

        print("Starting worker " + str(self.number))
        best_mean_episode_reward = -float('inf')
        with sess.as_default(), sess.graph.as_default():
            while not coord.should_stop():
                sess.run(self.update_local_ops)
                #episode_buffer = []
                episode_reward = 0
                episode_step_count = 0
                d = False
                s = self.env.reset()
                s = process_frame(s)
                if epsilon > 0.01:
                    epsilon = epsilon * 0.997
                while not d:
                    #self.env.render()
                    GLOBAL_STEP += 1
                    #Take an action using probabilities from policy network output.
                    if random.random() > epsilon:
                        a_dist_list = sess.run(
                            self.local_Q.Q,
                            feed_dict={self.local_Q.inputs: [s]})
                        a_dist = a_dist_list[0]
                        a = np.argmax(a_dist)
                    else:
                        a = random.randint(0, 5)

                    s1, r, d, _ = self.env.step(a)
                    if not d:
                        s1 = process_frame(s1)
                    else:
                        s1 = s
                    self.replaymemory.add([s, a, r, s1, d])
                    episode_reward += r
                    s = s1
                    total_steps += 1
                    episode_step_count += 1
                    if total_steps % 2 == 0 and not d and total_steps > 50000:
                        episode_buffer, tree_idx, ISWeights = self.replaymemory.sample(
                            batch_size)
                        l, g_n, v_n, Q_target, u = self.train(
                            episode_buffer, sess, gamma, ISWeights)
                        u = np.mean(u, axis=1) + 1e-6
                        self.replaymemory.update_priorities(tree_idx, u)
                        #sess.run(self.update_local_ops)
                    if d:
                        break
                sess.run(self.update_local_ops)
                self.episode_rewards.append(episode_reward)
                self.episode_lengths.append(episode_step_count)

                # Periodically save gifs of episodes, model parameters, and summary statistics.
                if episode_count % 5 == 0 and episode_count != 0 and total_steps > max_memory:
                    if self.name == 'worker_0' and episode_count % 5 == 0:
                        print('\n episode: ', episode_count, 'global_step:', \
                              GLOBAL_STEP, 'mean episode reward: ', np.mean(self.episode_rewards[-10:]), \
                              'epsilon: ', epsilon)

                    print('loss', l, 'Qtargetmean', np.mean(Q_target))
                    #print 'p_target', p_target
                    if episode_count % 100 == 0 and self.name == 'worker_0' and total_steps > 10000:
                        saver.save(
                            sess, self.model_path + '/qr-dqn-' +
                            str(episode_count) + '.cptk')
                        print("Saved Model")

                    mean_reward = np.mean(self.episode_rewards[-100:])
                    if episode_count > 20 and best_mean_episode_reward < mean_reward:
                        best_mean_episode_reward = mean_reward

                if self.name == 'worker_0':
                    sess.run(self.increment)
                    #if episode_count%1==0:
                    #print('\r {} {}'.format(episode_count, episode_reward),end=' ')
                episode_count += 1
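
A hedged launch sketch for the Worker class above (an illustration, not part of either Worker example): one global Q_Network plus several worker threads coordinated in the usual A3C-style setup. make_env, num_workers, s_size, a_size, gamma, and model_path are placeholder names, and the module-level globals the class itself relies on (batch_size, N, max_memory, GLOBAL_STEP, process_frame, update_target_graph) are assumed to be defined elsewhere.

# Hypothetical multi-threaded launch for the Worker class above.
import threading
import tensorflow as tf

with tf.device('/cpu:0'):
    global_episodes = tf.Variable(0, dtype=tf.int32, name='global_episodes',
                                  trainable=False)
    trainer = tf.train.AdamOptimizer(learning_rate=1e-4)
    master_network = Q_Network(s_size, a_size, 'global', None)  # global copy the workers sync from
    workers = [Worker(make_env(i), i, s_size, a_size, trainer, model_path,
                      global_episodes) for i in range(num_workers)]

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    threads = []
    for worker in workers:
        t = threading.Thread(target=lambda w=worker: w.work(gamma, sess, coord, saver))
        t.start()
        threads.append(t)
    coord.join(threads)
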
Example #5
class Agent:
    '''Interact with and learn from the environment.'''
    def __init__(self, state_size, action_size, seed, is_double_q=False):
        '''Initialize an Agent.

        Params
        ======
            state_size (int): the dimension of the state
            action_size (int): the number of actions
            seed (int): random seed
            is_double_q (bool): use Double DQN targets if True
        '''

        self.state_size = state_size
        self.action_size = action_size
        self.seed = seed
        random.seed(seed)
        self.t_step = 0  # Initialize time step (for tracking LEARN_EVERY_STEP and UPDATE_EVERY_STEP)
        self.running_loss = 0
        self.training_cnt = 0

        self.is_double_q = is_double_q

        self.qnetwork_local = QNetwork(self.state_size, self.action_size,
                                       seed).to(device)
        self.qnetwork_target = QNetwork(self.state_size, self.action_size,
                                        seed).to(device)

        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
        self.replay_memory = ReplayMemory(BATCH_SIZE, BUFFER_SIZE, seed)

    def act(self, state, mode, epsilon=None):
        '''Returns actions for given state as per current policy.
        
        Params
        ======
            state (array-like): current state
            mode (string): train or test
            epsilon (float): for epsilon-greedy action selection

        '''
        state = torch.from_numpy(state).float().unsqueeze(0).to(
            device)  # shape: (1, state_size)

        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local.forward(state)
        self.qnetwork_local.train()

        if mode == 'test':
            action = np.argmax(action_values.cpu().data.numpy()
                               )  # pull action values from gpu to local cpu

        elif mode == 'train':
            if random.random() <= epsilon:  # random action
                action = random.choice(np.arange(self.action_size))
            else:  # greedy action
                action = np.argmax(action_values.cpu().data.numpy(
                ))  # pull action values from gpu to local cpu

        return action

    def step(self, state, action, reward, next_state, done):
        # add new experience in memory
        self.replay_memory.add(state, action, reward, next_state, done)

        # activate learning every few steps
        self.t_step = self.t_step + 1
        if self.t_step % LEARN_EVERY_STEP == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.replay_memory) >= BUFFER_SIZE:
                experiences = self.replay_memory.sample(device)
                self.learn(experiences, GAMMA)

    def learn(self, experiences, gamma):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Variable]): tuple of (s, a, r, s', done) tuples 
            gamma (float): discount factor

        """

        # compute and minimize the loss
        states, actions, rewards, next_states, dones = experiences

        q_local_chosen_action_values = self.qnetwork_local.forward(
            states).gather(1, actions)
        q_target_action_values = self.qnetwork_target.forward(
            next_states).detach()  # detach from graph, don't backpropagate

        if self.is_double_q:
            q_local_next_actions = self.qnetwork_local.forward(
                next_states).detach().max(1)[1].unsqueeze(
                    1)  # shape (batch_size, 1)
            q_target_best_action_values = q_target_action_values.gather(
                1, q_local_next_actions)  # Double DQN

        else:
            q_target_best_action_values = q_target_action_values.max(
                1)[0].unsqueeze(1)  # shape (batch_size, 1)

        q_target_values = rewards + gamma * q_target_best_action_values * (
            1 - dones)  # zero value for terminal state

        td_errors = q_target_values - q_local_chosen_action_values

        loss = (td_errors**2).mean()

        self.running_loss += float(loss.cpu().data.numpy())
        self.training_cnt += 1

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        if self.t_step % UPDATE_EVERY_STEP == 0:
            self.update(self.qnetwork_local, self.qnetwork_target)

    def update(self, local_network, target_network):
        """Hard update of model parameters, as in the original DQN paper.
        
        Params
        ======
            local_network (PyTorch model): weights will be copied from
            target_network (PyTorch model): weights will be copied to
        """
        for local_param, target_param in zip(local_network.parameters(),
                                             target_network.parameters()):
            target_param.data.copy_(local_param.data)
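
A hedged training-loop sketch for the PyTorch Agent above (not part of the original example). It assumes a Gym-style environment with the classic 4-tuple step API and that the module-level constants the class uses (BATCH_SIZE, BUFFER_SIZE, GAMMA, LR, LEARN_EVERY_STEP, UPDATE_EVERY_STEP, device) are defined in the surrounding file. The environment name, state/action sizes, and decay schedule are placeholders.

# Hypothetical episode loop driving the Agent above with epsilon-greedy exploration.
import gym

env = gym.make('LunarLander-v2')  # placeholder environment (8-dim state, 4 actions)
agent = Agent(state_size=8, action_size=4, seed=0, is_double_q=True)

epsilon = 1.0
for episode in range(1000):
    state = env.reset()
    done = False
    while not done:
        action = agent.act(state, mode='train', epsilon=epsilon)
        next_state, reward, done, _ = env.step(action)
        agent.step(state, action, reward, next_state, done)
        state = next_state
    epsilon = max(0.01, epsilon * 0.995)  # decay exploration over episodes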