Example #1

import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

# `Agent` (the base class) and `DQN` (the Q-network) are assumed to be provided
# elsewhere in the project; a sketch of a plausible DQN model follows the class.

class Agent_DQN(Agent):
    def __init__(self, env, args):
        """
        Initialize everything you need here.
        For example:
            parameters for the neural network
            initialize the Q net and the target Q net
            parameters for the replay buffer
            parameters for Q-learning; decaying epsilon-greedy
            ...
        """
        super(Agent_DQN, self).__init__(env)
        ###########################
        # YOUR IMPLEMENTATION HERE #
        # import arguments
        self.args = args
        self.env = env
        self.batch_size = self.args.batch_size
        self.gamma = self.args.gamma
        self.lr = self.args.learning_rate
        self.memory_cap = self.args.memory_cap
        self.n_episode = self.args.n_episode
        self.n_step = self.args.n_step
        self.update_f = self.args.update_f
        self.explore_step = self.args.explore_step
        self.action_size = self.args.action_size
        self.algorithm = self.args.algorithm
        self.save_path = "dqn/"
        print('using algorithm ', self.algorithm)

        # whether continue training
        self.load_model = self.args.load_model

        # unify tensor types according to the device
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        print('using device ',
              torch.cuda.get_device_name(0) if self.use_cuda else 'cpu')
        self.FloatTensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
        self.LongTensor = torch.cuda.LongTensor if self.use_cuda else torch.LongTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_cuda else torch.ByteTensor
        self.Tensor = self.FloatTensor  # default type

        # epsilon decay
        self.epsilon = 1.0
        self.epsilon_min = 0.025
        self.epsilon_decay = (self.epsilon -
                              self.epsilon_min) / self.explore_step

        # Create the policy net and the target net
        self.policy_net = DQN()
        self.policy_net.to(self.device)
        if self.algorithm == 'DDQN':
            self.policy_net_2 = DQN()
            self.policy_net_2.to(self.device)
        self.target_net = DQN()
        self.target_net.to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        # replay buffer
        self.memory = []

        # optimizer
        self.optimizer = optim.Adam(params=self.policy_net.parameters(),
                                    lr=self.lr)
        if self.algorithm == 'DDQN':
            self.optimizer_2 = optim.Adam(
                params=self.policy_net_2.parameters(), lr=self.lr)

        # other
        self.f_skip = 4  # frame skip
        self.n_avg_reward = 100
        self.f_print = 100
        self.print_test = False

        if args.test_dqn:
            #you can load your model here
            print('loading trained model')
            ###########################
            # YOUR IMPLEMENTATION HERE #
            self.policy_net.load_state_dict(
                torch.load('model.pth', map_location=self.device))
            self.target_net.load_state_dict(self.policy_net.state_dict())
            if self.algorithm == 'DDQN':
                self.policy_net_2.load_state_dict(
                    torch.load('model.pth', map_location=self.device))
            self.print_test = True

    def init_game_setting(self):
        """
        Testing function will call this function at the beginning of a new game
        Put anything you want to initialize if necessary.
        If no parameters need to be initialized, you can leave it as blank.
        """
        ###########################
        # YOUR IMPLEMENTATION HERE #
        state = self.env.reset() / 255.
        self.last_life = 5
        self.step = 0
        done = False
        total_reward = 0
        ###########################
        return state, done, total_reward

    def make_action(self, observation, test=False):
        """
        Return predicted action of your agent
        Input:
            observation: np.array
                stack 4 last preprocessed frames, shape: (84, 84, 4)
        Return:
            action: int
                the predicted action from trained model
        """
        ###########################
        # YOUR IMPLEMENTATION HERE #
        if test:
            self.epsilon = self.epsilon_min
            observation = observation / 255.
        else:
            self.epsilon = max(self.epsilon - self.epsilon_decay,
                               self.epsilon_min)
        if random.random() > self.epsilon:
            observation = self.Tensor(observation.reshape(
                (1, 84, 84, 4))).transpose(1, 3).transpose(2, 3)
            state_action_value = self.policy_net(
                observation).data.cpu().numpy()
            action = np.argmax(state_action_value)
        else:
            action = random.randint(0, self.action_size - 1)
        ###########################
        return action

    def push(self, state, action, reward, next_state, dead, done):
        """ You can add additional arguments as you need. 
        Push new data to buffer and remove the old one if the buffer is full.
        
        Hints:
        -----
            you can consider using a deque(maxlen=10000) as the buffer
        """
        ###########################
        # YOUR IMPLEMENTATION HERE #
        if len(self.memory) >= self.memory_cap:
            self.memory.pop(0)
        self.memory.append((state, action, reward, next_state, dead, done))
        ###########################

    def replay_buffer(self):
        """ You can add additional arguments as you need.
        Select batch from buffer.
        """
        ###########################
        # YOUR IMPLEMENTATION HERE #
        self.mini_batch = random.sample(self.memory, self.batch_size)
        ###########################
        return

    def train(self):
        """
        Implement your training algorithm here
        """
        ###########################
        # YOUR IMPLEMENTATION HERE #
        # initialize
        self.steps_done = 0
        self.steps = []
        self.rewards = []
        self.mean_rewards = []
        self.best_reward = 0
        self.last_saved_reward = 0

        start = time.time()
        logfile = open('dqn.log', 'w+')
        # continue training
        if self.load_model:
            self.policy_net.load_state_dict(
                torch.load(self.save_path + 'model.pth',
                           map_location=self.device))
            self.target_net.load_state_dict(self.policy_net.state_dict())
            self.epsilon = self.epsilon_min

        for episode in range(self.n_episode):
            state, done, total_reward = self.init_game_setting()
            while (not done) and self.step < 10000:
                # move to next state
                self.step += 1
                self.steps_done += 1
                action = self.make_action(state)
                next_state, reward, done, life = self.env.step(action)
                # track remaining lives to detect when a life is lost
                now_life = life['ale.lives']
                dead = (now_life < self.last_life)
                self.last_life = now_life
                next_state = next_state / 255.
                # Store the transition in memory
                self.push(state, action, reward, next_state, dead, done)
                state = next_state
                total_reward += reward

                if len(self.memory
                       ) >= self.n_step and self.steps_done % self.f_skip == 0:
                    if self.algorithm == 'DQN':
                        self.optimize_DQN()
                    elif self.algorithm == 'DDQN':
                        self.optimize_DDQN()
                if self.steps_done % self.update_f == 0:
                    self.target_net.load_state_dict(
                        self.policy_net.state_dict())

            self.rewards.append(total_reward)
            self.mean_reward = np.mean(self.rewards[-self.n_avg_reward:])
            self.mean_rewards.append(self.mean_reward)
            self.steps.append(self.step)
            # print progress in terminal
            progress = "Episode: " + str(
                episode) + ",\tCurrent mean reward: " + "{:.2f}".format(
                    self.mean_reward
                ) + ',\tBest mean reward: ' + "{:.2f}".format(self.best_reward)
            progress += ",\tCurrent Reward: " + str(
                total_reward) + ",\tTime: " + time.strftime(
                    '%H:%M:%S', time.gmtime(time.time() - start))
            print(progress)
            print(episode,
                  self.mean_reward,
                  self.best_reward,
                  total_reward,
                  time.time() - start,
                  file=logfile)
            logfile.flush()
            if (episode + 1) % self.f_print == 0:
                self.plots()
            # save the best model
            if self.mean_reward > self.best_reward and self.steps_done > self.n_step:
                checkpoint_path = self.save_path + 'model.pth'
                torch.save(self.policy_net.state_dict(), checkpoint_path)
                self.last_saved_reward = self.mean_reward
                self.best_reward = max(self.mean_reward, self.best_reward)
        ###########################

    def optimize_DQN(self):
        # sample
        self.replay_buffer()
        state, action, reward, next_state, dead, done = zip(*self.mini_batch)

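        # convert the batch from NHWC (batch, 84, 84, 4) to NCHW (batch, 4, 84, 84)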
        state = self.Tensor(np.float32(state)).permute(0, 3, 1,
                                                       2).to(self.device)
        action = self.LongTensor(action).to(self.device)
        reward = self.Tensor(reward).to(self.device)
        next_state = self.Tensor(np.float32(next_state)).permute(
            0, 3, 1, 2).to(self.device)
        dead = self.Tensor(dead).to(self.device)
        done = self.Tensor(done).to(self.device)

        # Compute Q(s_t, a)
        state_action_values = self.policy_net(state).gather(
            1, action.unsqueeze(1)).squeeze(1)
        # Compute max_a Q_target(s', a) for the next states
        next_state_values = self.target_net(next_state).detach().max(1)[0]
        # Compute the expected Q value. stop update if done
        expected_state_action_values = reward + (next_state_values *
                                                 self.gamma) * (1 - done)
        # Compute Huber loss
        self.loss = F.smooth_l1_loss(state_action_values,
                                     expected_state_action_values.data)
        # Optimize the model
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
        return

    def optimize_DDQN(self):
        # sample
        self.replay_buffer()
        state, action, reward, next_state, dead, done = zip(*self.mini_batch)

        # convert from NHWC (batch, 84, 84, 4) to NCHW (batch, 4, 84, 84) via permute(0, 3, 1, 2)
        state = self.Tensor(np.float32(state)).permute(0, 3, 1,
                                                       2).to(self.device)
        action = self.LongTensor(action).to(self.device)
        reward = self.Tensor(reward).to(self.device)
        next_state = self.Tensor(np.float32(next_state)).permute(
            0, 3, 1, 2).to(self.device)
        dead = self.Tensor(dead).to(self.device)
        done = self.Tensor(done).to(self.device)

        # Compute Q(s_t, a)
        state_action_values = self.policy_net(state).gather(
            1, action.unsqueeze(1)).squeeze(1)
        state_action_values_2 = self.policy_net_2(state).gather(
            1, action.unsqueeze(1)).squeeze(1)
        # Compute max_a Q(s', a) from each estimator and take the element-wise
        # minimum (the second estimate must come from the second network,
        # otherwise the min compares identical values)
        next_state_values = self.target_net(next_state).detach().max(1)[0]
        next_state_values_2 = self.policy_net_2(next_state).detach().max(1)[0]
        next_state_values = torch.min(next_state_values, next_state_values_2)
        # Compute the expected Q value. stop update if done
        expected_state_action_values = reward + (next_state_values *
                                                 self.gamma) * (1 - done)
        # Compute Huber loss
        self.loss = F.smooth_l1_loss(state_action_values,
                                     expected_state_action_values.data)
        self.loss_2 = F.smooth_l1_loss(state_action_values_2,
                                       expected_state_action_values.data)
        # Optimize the model
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
        self.optimizer_2.zero_grad()
        self.loss_2.backward()
        self.optimizer_2.step()
        return

    def plots(self):
        fig1 = plt.figure(1)
        plt.clf()
        plt.title('Training_Steps_per_Episode')
        plt.xlabel('Episode')
        plt.ylabel('Steps')
        plt.plot(self.steps)
        fig1.savefig(self.save_path + 'steps.png')

        fig2 = plt.figure(2)
        plt.clf()
        plt.title('Training...')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.plot(self.rewards)

        if len(self.rewards) >= self.n_avg_reward:
            plt.plot(self.mean_rewards)
        fig2.savefig(self.save_path + 'rewards.png')

        rewards = np.array(self.rewards)
        np.save(self.save_path + 'rewards.npy', rewards)
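
The `DQN` model used above is constructed with no arguments, so the network definition itself is not part of this excerpt. Below is a minimal sketch of the kind of model that call implies: the standard Nature-DQN convolutional architecture over a stacked 4x84x84 input. The layer sizes and the `n_actions=4` default are assumptions for illustration, not taken from the original code; the later examples build their networks with explicit arguments (`DQN(n_actions)` and `DQN(4, n_actions)`), so each presumably ships its own model file.

import torch.nn as nn
import torch.nn.functional as F


class DQN(nn.Module):
    """Sketch of a Nature-DQN style Q-network for 4x84x84 stacked frames (assumed, not the original)."""

    def __init__(self, n_actions=4):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # 84x84 -> 20x20 -> 9x9 -> 7x7 feature maps with 64 channels
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, n_actions)

    def forward(self, x):
        # x: (batch, 4, 84, 84) float tensor scaled to [0, 1]
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc1(x.view(x.size(0), -1)))
        return self.fc2(x)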
Example #2

import math
import os
import random
from itertools import count

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

# `Agent`, `DQN`, `ReplayMemory`, `Transition`, and the module-level names
# (BATCH_SIZE, GAMMA, EPS_START, EPS_END, EPS_DECAY, TARGET_UPDATE, steps_done,
# device, Tensor, FloatTensor, LongTensor) are assumed to be defined elsewhere;
# a sketch of plausible definitions follows the class.

class Agent_DQN(Agent):
    def __init__(self, env, args):
        """
        Initialize everything you need here.
        For example: building your model
        """

        super(Agent_DQN, self).__init__(env)

        ##################
        # YOUR CODE HERE #
        ##################
        self.env = env
        self.batch_size = BATCH_SIZE
        self.gamma = GAMMA
        self.eps_start = EPS_START
        self.eps_decay = EPS_DECAY
        self.TARGET_UPDATE = TARGET_UPDATE

        self.policy_net = DQN(self.env.action_space.n)
        self.target_net = DQN(self.env.action_space.n)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.policy_net.to(device)
        self.target_net.to(device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=1.5e-4)
        self.memory = ReplayMemory(10000)

        if args.test_dqn:
            # you can load your model here
            print('loading trained model')
            self.policy_net.load_state_dict(
                torch.load(os.path.join('save_dir', 'model-best.pth'),
                           map_location=torch.device('cpu')))
            self.policy_net.eval()

    def init_game_setting(self):
        """

        Testing function will call this function at the beginning of a new game
        Put anything you want to initialize if necessary

        """
        ##################
        # YOUR CODE HERE #
        ##################
        pass

    def train(self):
        """
        Implement your training algorithm here
        """
        ##################
        # YOUR CODE HERE #
        ##################
        logfile = open('simple_dqn.log', 'w+')
        step = 0
        num_episodes = 1400000
        for i_episode in range(num_episodes):
            # Initialize the environment and state
            observation = self.env.reset()
            observation = observation.transpose((2, 0, 1))
            observation = observation[np.newaxis, :]
            state = observation
            sum_reward = 0
            for t in count():
                # Select and perform an action
                action = self.make_action(state, test=False)
                next_state, reward, done, _ = self.env.step(action.item())
                reward = np.clip(reward, -1., 1.)
                next_state = next_state.transpose((2, 0, 1))
                next_state = next_state[np.newaxis, :]
                sum_reward += reward
                reward = Tensor([reward])
                step += 1

                # Store the transition in memory; None marks a terminal
                # next_state so optimize_model()'s non_final_mask excludes it
                self.memory.push(torch.from_numpy(state), action,
                                 None if done else torch.from_numpy(next_state),
                                 reward)

                # Observe new state
                if not done:
                    state = next_state
                else:
                    state = None

                if step >= 5000 and step % 5000 == 0:
                    # Perform one step of the optimization on the policy
                    # network, then sync the target network
                    self.optimize_model()
                    self.target_net.load_state_dict(
                        self.policy_net.state_dict())

                if done:
                    print(
                        'resetting env. episode %d \'s step=%d reward total was %d.'
                        % (i_episode + 1, step, sum_reward))
                    print(
                        'resetting env. episode %d \'s step=%d reward total was %d.'
                        % (i_episode + 1, step, sum_reward),
                        file=logfile)
                    logfile.flush()
                    break

            # Update the target network
            # if i_episode % TARGET_UPDATE == 0:
            #     print("Update the target net.")
            #     # print(self.policy_net.state_dict())
            #     self.target_net.load_state_dict(self.policy_net.state_dict())
            if i_episode % 50 == 0:
                checkpoint_path = os.path.join('save_dir', 'model-best.pth')
                torch.save(self.policy_net.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))

    def make_action(self, observation, test=True):
        """
        Return predicted action of your agent

        Input:
            observation: np.array
                stack 4 last preprocessed frames, shape: (84, 84, 4)

        Return:
            action: int
                the predicted action from trained model
        """
        ##################
        # YOUR CODE HERE #
        ##################
        global steps_done
        if test:
            observation = observation.transpose((2, 0, 1))
            observation = observation[np.newaxis, :]
            # self.policy_net.eval()
            with torch.no_grad():
                return self.policy_net(
                    torch.from_numpy(observation).type(
                        FloatTensor)).max(1)[1].view(1, 1).item()
        else:
            self.policy_net.eval()
            sample = random.random()
            eps_threshold = EPS_END + (EPS_START - EPS_END) * \
                math.exp(-1. * steps_done / EPS_DECAY)
            steps_done += 1
            if sample > eps_threshold:
                with torch.no_grad():
                    return self.policy_net(
                        torch.from_numpy(observation).type(
                            FloatTensor)).max(1)[1].view(1, 1)
            else:
                return LongTensor(
                    [[random.randrange(self.env.action_space.n)]])

    def optimize_model(self):
        if len(self.memory) < BATCH_SIZE:
            return
        transitions = self.memory.sample(BATCH_SIZE)
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
        # detailed explanation). This converts batch-array of Transitions
        # to Transition of batch-arrays.
        batch = Transition(*zip(*transitions))

        # Compute a mask of non-final states and concatenate the batch elements
        # (a final state would've been the one after which simulation ended)
        non_final_mask = torch.tensor(tuple(
            map(lambda s: s is not None, batch.next_state)),
                                      device=device,
                                      dtype=torch.bool)
        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None]).to(device)
        state_batch = torch.cat(batch.state).to(device)
        action_batch = torch.cat(batch.action).to(device)
        reward_batch = torch.cat(batch.reward).to(device)

        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of actions taken. These are the actions which would've been taken
        # for each batch state according to policy_net
        state_action_values = self.policy_net(state_batch.float()).gather(
            1, action_batch)

        # Compute V(s_{t+1}) for all next states.
        # Expected values of actions for non_final_next_states are computed based
        # on the "older" target_net; selecting their best reward with max(1)[0].
        # This is merged based on the mask, such that we'll have either the expected
        # state value or 0 in case the state was final.
        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        next_state_values[non_final_mask] = self.target_net(
            non_final_next_states.float()).max(1)[0].detach()
        # Compute the expected Q values
        expected_state_action_values = (next_state_values *
                                        GAMMA) + reward_batch

        # Compute Huber loss
        loss = F.smooth_l1_loss(state_action_values,
                                expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
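
Example #2 leans on module-level names that are not shown in the excerpt: the hyperparameter constants, the device and tensor aliases, the global `steps_done` counter, and the `Transition`/`ReplayMemory` pair familiar from the PyTorch DQN tutorial. A minimal sketch consistent with how those names are used above is given below; the constant values are placeholders (typical choices), not the original author's settings.

import random
from collections import namedtuple

import torch

# Placeholder hyperparameters (typical values; the originals are not shown).
BATCH_SIZE = 32
GAMMA = 0.99
EPS_START = 1.0
EPS_END = 0.025
EPS_DECAY = 200000
TARGET_UPDATE = 1000

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
Tensor = FloatTensor

steps_done = 0  # global step counter used by make_action's epsilon schedule

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-size ring buffer of transitions (as in the PyTorch DQN tutorial)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)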
Example #3

import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

# `Agent`, `DQN`, `ReplayMemory`, `Transition`, `TensorboardSummary`, and the
# `to_tensor` / `to_float` helpers are assumed to be provided by the project;
# sketches of the assumed pieces follow the code below.

class Agent_DQN(Agent):
    def __init__(self, env, args):
        """
        Initialize everything you need here.
        For example:
            parameters for the neural network
            initialize the Q net and the target Q net
            parameters for the replay buffer
            parameters for Q-learning; decaying epsilon-greedy
            ...
        """

        super(Agent_DQN, self).__init__(env)
        ###########################
        # YOUR IMPLEMENTATION HERE #
        self.env = env
        self.args = args
        self.gamma = self.args.gamma
        self.batch_size = self.args.batch_size
        self.memory_cap = self.args.memory_cap
        self.n_episode = self.args.n_episode
        self.lr = self.args.learning_rate

        self.epsilon = self.args.epsilon
        self.epsilon_decay_window = self.args.epsilon_decay_window
        self.epsilon_min = self.args.epsilon_min
        self.epsilon_decay = (self.epsilon -
                              self.epsilon_min) / self.epsilon_decay_window

        self.n_step = self.args.n_step
        self.f_update = self.args.f_update
        self.load_model = self.args.load_model
        self.action_size = self.args.action_size
        #         self.algorithm = self.args.algorithm

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        print('using device ',
              torch.cuda.get_device_name(0) if self.use_cuda else 'cpu')
        self.FloatTensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
        self.LongTensor = torch.cuda.LongTensor if self.use_cuda else torch.LongTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_cuda else torch.ByteTensor
        self.Tensor = self.FloatTensor

        # Create the policy net and the target net
        self.policy_net = DQN()
        self.policy_net.to(self.device)
        #         if self.algorithm == 'DDQN':
        #             self.policy_net_2 = DQN()
        #             self.policy_net_2.to(self.device)
        self.target_net = DQN()
        self.target_net.to(self.device)
        self.policy_net.train()
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(params=self.policy_net.parameters(),
                                    lr=self.lr)
        # buffer
        self.memory = []

        ##
        self.mean_window = 100
        self.print_frequency = 100
        self.out_dir = "DQN_Module_b1_1/"

        if args.test_dqn:
            #you can load your model here
            print('loading trained model')
            ###########################
            # YOUR IMPLEMENTATION HERE #
            self.policy_net.load_state_dict(
                torch.load('model.pth', map_location=self.device))
            self.target_net.load_state_dict(self.policy_net.state_dict())
            # DDQN branch disabled: self.algorithm and policy_net_2 are not
            # defined in this version (see the commented-out code in __init__)
            #             if self.algorithm == 'DDQN':
            #                 self.policy_net_2.load_state_dict(
            #                     torch.load('model.pth', map_location=self.device))
            self.print_test = True

    def init_game_setting(self):
        """
        Testing function will call this function at the beginning of a new game
        Put anything you want to initialize if necessary.
        If no parameters need to be initialized, you can leave it as blank.
        """
        ###########################
        # YOUR IMPLEMENTATION HERE #

        ###########################
        pass

    def make_action(self, observation, test=False):
        """
        Return predicted action of your agent
        Input:
            observation: np.array
                stack 4 last preprocessed frames, shape: (84, 84, 4)
        Return:
            action: int
                the predicted action from trained model
        """
        ###########################
        # YOUR IMPLEMENTATION HERE #
        if test:
            self.epsilon = self.epsilon_min * 0.5
            observation = observation / 255.
        else:
            self.epsilon = max(self.epsilon - self.epsilon_decay,
                               self.epsilon_min)
        if random.random() > self.epsilon:
            observation = self.Tensor(observation.reshape(
                (1, 84, 84, 4))).transpose(1, 3).transpose(2, 3)
            state_action_value = self.policy_net(
                observation).data.cpu().numpy()
            action = np.argmax(state_action_value)
        else:
            action = random.randint(0, self.action_size - 1)
        ###########################
        return action

    def push(self, state, action, reward, next_state, done):
        """ You can add additional arguments as you need. 
        Push new data to buffer and remove the old one if the buffer is full.
        
        Hints:
        -----
            you can consider using a deque(maxlen=10000) as the buffer
        """
        ###########################
        # YOUR IMPLEMENTATION HERE #
        if len(self.memory) >= self.memory_cap:
            self.memory.pop(0)
        self.memory.append((state, action, reward, next_state, done))
        ###########################

    def replay_buffer(self):
        """ You can add additional arguments as you need.
        Select batch from buffer.
        """
        ###########################
        # YOUR IMPLEMENTATION HERE #
        self.mini_batch = random.sample(self.memory, self.batch_size)
        ###########################
        return

    def train(self):
        """
        Implement your training algorithm here
        """
        ###########################
        # YOUR IMPLEMENTATION HERE #
        self.steps_done = 0
        self.steps = []
        self.rewards = []
        self.mean_rewards = []
        self.time = []
        self.best_reward = 0
        self.last_saved_reward = 0
        self.start_time = time.time()
        print('train')
        # continue training from where it stopped
        if self.load_model:
            self.policy_net.load_state_dict(
                torch.load(self.out_dir + 'model.pth',
                           map_location=self.device))
            self.target_net.load_state_dict(self.policy_net.state_dict())
            self.epsilon = self.epsilon_min
            print('Loaded')
        for episode in range(self.n_episode):
            # Initialize the environment and state
            state = self.env.reset() / 255.
            #             self.last_life = 5
            total_reward = 0
            self.step = 0
            done = False

            while (not done) and self.step < 10000:
                # move to next state
                self.step += 1
                self.steps_done += 1
                action = self.make_action(state)
                next_state, reward, done, life = self.env.step(action)
                # lives matter
                #                 self.now_life = life['ale.lives']
                #                 dead = self.now_life < self.last_life
                #                 self.last_life = self.now_life
                next_state = next_state / 255.
                # Store the transition in memory
                self.push(state, action, reward, next_state, done)
                state = next_state
                total_reward += reward

                if done:
                    self.rewards.append(total_reward)
                    self.mean_reward = np.mean(
                        self.rewards[-self.mean_window:])
                    self.mean_rewards.append(self.mean_reward)
                    self.time.append(time.time() - self.start_time)
                    self.steps.append(self.step)

                    # print the process to terminal
                    progress = "episode: " + str(
                        episode) + ",\t epsilon: " + str(
                            self.epsilon
                        ) + ",\t Current mean reward: " + "{:.2f}".format(
                            self.mean_reward)
                    progress += ',\t Best mean reward: ' + "{:.2f}".format(
                        self.best_reward) + ",\t time: " + time.strftime(
                            '%H:%M:%S', time.gmtime(self.time[-1]))
                    print(progress)

                    if episode % self.print_frequency == 0:
                        self.print_and_plot()
                    # save the best model
                    if self.mean_reward > self.best_reward and len(
                            self.memory) >= 5000:
                        print('~~~~~~~~~~<Model updated with best reward = ',
                              self.mean_reward, '>~~~~~~~~~~')
                        checkpoint_path = self.out_dir + 'model.pth'
                        torch.save(self.policy_net.state_dict(),
                                   checkpoint_path)
                        self.last_saved_reward = self.mean_reward
                        self.best_reward = self.mean_reward

                if len(self.memory) >= 5000 and self.steps_done % 4 == 0:
                    #                     if self.algorithm == 'DQN':
                    self.optimize_DQN()
                if self.steps_done % self.f_update == 0:
                    self.target_net.load_state_dict(
                        self.policy_net.state_dict())
                    # print('-------<target net updated at step,', self.steps_done, '>-------')

        ###########################

    def optimize_DQN(self):
        # sample
        self.replay_buffer()
        state, action, reward, next_state, done = zip(*self.mini_batch)

        # convert from NHWC (batch, 84, 84, 4) to NCHW (batch, 4, 84, 84) via permute(0, 3, 1, 2)
        state = self.Tensor(np.float32(state)).permute(0, 3, 1,
                                                       2).to(self.device)
        action = self.LongTensor(action).to(self.device)
        reward = self.Tensor(reward).to(self.device)
        next_state = self.Tensor(np.float32(next_state)).permute(
            0, 3, 1, 2).to(self.device)

        done = self.Tensor(done).to(self.device)

        # Compute Q(s_t, a)
        state_action_values = self.policy_net(state).gather(
            1, action.unsqueeze(1)).squeeze(1)
        # Compute max_a Q_target(s', a) for the next states
        next_state_values = self.target_net(next_state).detach().max(1)[0]
        # Compute the expected Q value. stop update if done
        expected_state_action_values = reward + (next_state_values *
                                                 self.gamma) * (1 - done)
        # Compute Huber loss
        self.loss = F.smooth_l1_loss(state_action_values,
                                     expected_state_action_values.data)
        # Optimize the model
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
        return

    def print_and_plot(self):
        fig1 = plt.figure(1)
        plt.clf()
        plt.title('Training...')
        plt.xlabel('Episode')
        plt.ylabel('Steps')
        plt.plot(self.steps)
        fig1.savefig(self.out_dir + 'steps.png')

        fig2 = plt.figure(2)
        plt.clf()
        plt.title('Training...')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.plot(self.mean_rewards)
        fig2.savefig(self.out_dir + 'rewards.png')

        fig3 = plt.figure(3)
        plt.clf()
        plt.title('Training...')
        plt.xlabel('Episode')
        plt.ylabel('Time (s)')
        plt.plot(self.time)
        fig3.savefig(self.out_dir + 'time.png')
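
Both Agent_DQN classes above read their hyperparameters from an `args` namespace whose parser is not included in the excerpts. The sketch below lists the attributes Example #3's agent actually reads (Example #1 uses a similar set, with `update_f`, `explore_step`, and `algorithm` instead of the epsilon settings, and `Agent_DQN_Trainer` below reads a larger set of its own, such as `lr`, `gamma_reward_decay`, `observe_steps`, `explore_steps`, `save_dir`, and `log_dir`). Only the attribute names come from the code; every default value here is a placeholder for illustration.

import argparse


def parse_args():
    """Sketch of the argument namespace Example #3's Agent_DQN expects (values are placeholders)."""
    parser = argparse.ArgumentParser(description='DQN agent (sketch)')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--learning_rate', type=float, default=1.5e-4)
    parser.add_argument('--memory_cap', type=int, default=10000)
    parser.add_argument('--n_episode', type=int, default=100000)
    parser.add_argument('--epsilon', type=float, default=1.0)
    parser.add_argument('--epsilon_min', type=float, default=0.025)
    parser.add_argument('--epsilon_decay_window', type=int, default=100000)
    parser.add_argument('--n_step', type=int, default=5000)
    parser.add_argument('--f_update', type=int, default=1000)
    parser.add_argument('--action_size', type=int, default=4)
    parser.add_argument('--load_model', action='store_true')
    parser.add_argument('--test_dqn', action='store_true')
    return parser.parse_args()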
class Agent_DQN_Trainer(object):
    def __init__(self, env, args):

        # Training Parameters
        self.args = args
        self.env = env
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.gamma = args.gamma_reward_decay
        self.n_actions = env.action_space.n
        self.output_logs = args.output_logs
        self.step = 8e6
        self.curr_step = 0
        self.ckpt_path = args.save_dir
        self.epsilon = args.eps_start
        self.eps_end = args.eps_end
        self.target_update = args.update_target
        self.observe_steps = args.observe_steps
        self.explore_steps = args.explore_steps
        self.saver_steps = args.saver_steps
        self.resume = args.resume
        self.writer = TensorboardSummary(self.args.log_dir).create_summary()
        # Model Settings

        self.cuda = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.policy_net = DQN(4, self.n_actions)
        self.target_net = DQN(4, self.n_actions)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        # .to(device) handles both CUDA and CPU, so no explicit check is needed
        self.policy_net.to(self.cuda)
        self.target_net.to(self.cuda)

        self.target_net.eval()
        train_params = self.policy_net.parameters()
        self.optimizer = optim.RMSprop(train_params,
                                       self.lr,
                                       momentum=0.95,
                                       eps=0.01)
        self.memory = ReplayMemory(args.replay_memory_size)

        if args.resume:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)

            self.epsilon = checkpoint['epsilon']
            self.curr_step = checkpoint['step']

            self.policy_net.load_state_dict(checkpoint['policy_state_dict'])
            self.target_net.load_state_dict(checkpoint['target_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['episode']))

    def epsilon_greedy_policy(self, observation, nA, test=False):

        observation = to_float(observation).to(self.cuda)
        # print("size of observation->"+str(sys.getsizeof(observation.storage())))
        sample = random.random()

        if test:
            return self.policy_net(observation).max(1)[1].view(1, 1).item()

        if sample <= self.epsilon:
            action = torch.tensor([[random.randrange(self.n_actions)]],
                                  device=self.cuda,
                                  dtype=torch.long)
        else:
            with torch.no_grad():
                action = self.policy_net(observation).max(1)[1].view(1, 1)

        return action

    def optimize_model(self):

        transitions = self.memory.sample(self.batch_size)
        batch = Transition(*zip(*transitions))

        non_final_mask = torch.tensor(tuple(
            map(lambda s: s is not None, batch.next_state)),
                                      device=self.cuda,
                                      dtype=torch.bool)
        non_final_next_states = torch.cat(
            [to_float(s) for s in batch.next_state if s is not None])
        state_batch = torch.cat(
            [to_float(s).to(self.cuda) for s in batch.state])
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        state_action_values = self.policy_net(state_batch).gather(
            1, action_batch)
        next_state_values = torch.zeros(self.batch_size, device=self.cuda)
        next_state_values[non_final_mask] = self.target_net(
            non_final_next_states).max(1)[0].detach()

        expected_state_action_values = (next_state_values *
                                        self.gamma) + reward_batch
        loss = F.smooth_l1_loss(
            state_action_values.float(),
            expected_state_action_values.unsqueeze(1).float())
        self.optimizer.zero_grad()
        loss.backward()

        for param in self.policy_net.parameters():
            param.grad.data.clamp_(-1, 1)

        self.optimizer.step()
        return loss.item()

    def train(self):

        current_loss = 0
        train_rewards = []
        train_episode_len = 0.0
        file_loss = open(self.output_logs, "a")
        file_loss.write("episode,step,epsilon,reward,loss,length\n")
        print("Training Started")
        episode = 0
        loss = 0.0

        while self.curr_step < self.step:
            state = to_tensor(self.env.reset())

            # state is stored as torch.uint8; convert to float before passing it to the model
            done = False
            episode_reward = 0.0
            train_loss = 0
            s = 0  # length of episode
            while not done:
                # self.env.env.render()

                action = self.epsilon_greedy_policy(state, self.n_actions)

                new_state, reward, done, _ = self.env.step(
                    action.item())  # new_state torch.uint8 format
                new_state, reward = to_tensor(new_state).to(
                    self.cuda), torch.tensor([reward], device=self.cuda)
                episode_reward += reward
                self.memory.push(state, action, new_state, reward)

                if (self.curr_step > self.observe_steps) and (
                        self.curr_step % self.args.update_current) == 0:
                    loss = self.optimize_model()
                    train_loss += loss

                print(
                    'Step: %i,  Episode: %i,  Action: %i,  Reward: %.0f,  Epsilon: %.5f, Loss: %.5f'
                    % (self.curr_step, episode, action.item(), reward.item(),
                       self.epsilon, loss),
                    end='\r')

                if self.curr_step > self.observe_steps and self.curr_step % self.target_update == 0:
                    self.target_net.load_state_dict(
                        self.policy_net.state_dict())
                    # TO CHECK APPROXIMATELY HOW MUCH GPU MEMORY OUR REPLAY MEMORY IS CONSUMING
                    print(torch.cuda.get_device_name(0))
                    print('Memory Usage:')
                    print('Allocated:',
                          round(torch.cuda.memory_allocated(0) / 1024**3, 1),
                          'GB')
                    print('Cached:   ',
                          round(torch.cuda.memory_cached(0) / 1024**3, 1),
                          'GB')

                if self.epsilon > self.args.eps_end and self.curr_step > self.observe_steps:
                    interval = self.args.eps_start - self.args.eps_end
                    self.epsilon -= interval / float(self.args.explore_steps)

                self.curr_step += 1
                state = new_state
                s += 1

                if self.curr_step % self.args.saver_steps == 0 and episode != 0 and self.curr_step != 0:
                    k = {
                        'step': self.curr_step,
                        'epsilon': self.epsilon,
                        'episode': episode,
                        'policy_state_dict': self.policy_net.state_dict(),
                        'target_state_dict': self.target_net.state_dict(),
                        'optimizer': self.optimizer.state_dict()
                    }
                    filename = os.path.join(self.ckpt_path, 'ckpt.pth.tar')
                    torch.save(k, filename)

            episode += 1
            train_rewards.append(episode_reward.item())
            train_episode_len += s

            if episode % self.args.num_eval == 0 and episode != 0:
                current_loss = train_loss
                avg_reward_train = np.mean(train_rewards)
                train_rewards = []
                avg_episode_len_train = train_episode_len / float(
                    self.args.num_eval)
                train_episode_len = 0.0
                file_loss.write(
                    str(episode) + "," + str(self.curr_step) + "," +
                    "{:.4f}".format(self.epsilon) + "," +
                    "{:.2f}".format(avg_reward_train) + "," +
                    "{:.4f}".format(current_loss) + "," +
                    "{:.2f}".format(avg_episode_len_train) + "\n")
                file_loss.flush()
                self.writer.add_scalar('train_loss/episode(avg100)',
                                       current_loss, episode)
                self.writer.add_scalar('episode_reward/episode(avg100)',
                                       avg_reward_train, episode)
                self.writer.add_scalar('length of episode/episode(avg100)',
                                       avg_episode_len_train, episode)

            self.writer.add_scalar('train_loss/episode', train_loss, episode)
            self.writer.add_scalar('episode_reward/episode', episode_reward,
                                   episode)
            self.writer.add_scalar('epsilon/num_steps', self.epsilon,
                                   self.curr_step)
            self.writer.add_scalar('length of episode/episode', s, episode)

        print("GAME OVER")
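
Agent_DQN_Trainer stores observations as torch.uint8 tensors and only converts them when they are fed to the network, via `to_tensor` / `to_float`, and it logs through a `TensorboardSummary` wrapper; none of these helpers appear in the excerpt. The sketches below are consistent with how they are called above (a batched channel-first uint8 tensor, a float copy scaled to [0, 1], and a thin wrapper around torch.utils.tensorboard), but the exact originals may differ.

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter


def to_tensor(observation):
    """Sketch: (84, 84, 4) uint8 frame stack -> (1, 4, 84, 84) uint8 tensor (assumed helper)."""
    obs = np.ascontiguousarray(np.asarray(observation).transpose((2, 0, 1)))  # HWC -> CHW
    return torch.from_numpy(obs).unsqueeze(0)  # add a batch dimension


def to_float(state):
    """Sketch: convert a stored uint8 state to a float tensor scaled to [0, 1] (assumed helper)."""
    return state.float() / 255.


class TensorboardSummary(object):
    """Sketch: wrapper whose create_summary() hands back a SummaryWriter (assumed helper)."""

    def __init__(self, log_dir):
        self.log_dir = log_dir

    def create_summary(self):
        return SummaryWriter(log_dir=self.log_dir)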