Example #1
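For context, this snippet relies on the imports below. Network, PrioritizedReplayBuffer, and ReplayBuffer are project-local classes, so their import paths are left as commented placeholders (assumed names).

import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
# import wandb  # only needed if the commented-out logging is enabled

# project-local components (module names are assumptions):
# from network import Network
# from replay_buffer import PrioritizedReplayBuffer, ReplayBuffer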
class DQNAgent:
    def __init__(
            self,
            env,
            memory_size,
            batch_size,
            target_update=100,
            gamma=0.99,
            # replay parameters
            alpha=0.2,
            beta=0.6,
            prior_eps=1e-6,
            # Categorical DQN parameters
            v_min=0,
            v_max=200,
            atom_size=51,
            # N-step Learning
            n_step=3,
            start_train=32,
            save_weights=True,
            log=True,
            lr=0.001,
            seed=0,
            episodes=200):

        self.env = env

        obs_dim = self.env.observation_dim
        action_dim = self.env.action_dim

        self.batch_size = batch_size
        self.target_update = target_update
        self.gamma = gamma
        self.lr = lr
        self.memory_size = memory_size
        self.seed = seed

        # device: cpu / gpu
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)

        # memory for 1-step Learning
        self.beta = beta
        self.prior_eps = prior_eps
        self.memory = PrioritizedReplayBuffer(obs_dim,
                                              memory_size,
                                              batch_size,
                                              alpha=alpha)

        # memory for N-step Learning
        self.use_n_step = n_step > 1
        if self.use_n_step:
            self.n_step = n_step
            self.memory_n = ReplayBuffer(obs_dim,
                                         memory_size,
                                         batch_size,
                                         n_step=n_step,
                                         gamma=gamma)

        # Categorical DQN parameters
        self.v_min = v_min
        self.v_max = v_max
        self.atom_size = atom_size
        self.support = torch.linspace(self.v_min, self.v_max,
                                      self.atom_size).to(self.device)

        # networks: dqn, dqn_target
        self.dqn = Network(obs_dim, action_dim, self.atom_size,
                           self.support).to(self.device)
        self.dqn_target = Network(obs_dim, action_dim, self.atom_size,
                                  self.support).to(self.device)

        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        # optimizer
        self.optimizer = optim.Adam(self.dqn.parameters(), lr=self.lr)

        # transition to store in memory
        self.transition = list()

        self.fig, (self.ax1, self.ax2) = plt.subplots(2, figsize=(10, 10))

        self.start_train = start_train

        self.save_weights = save_weights

        self.time = datetime.datetime.now().timetuple()
        self.path = f"weights/{self.time[2]}-{self.time[1]}-{self.time[0]}_{self.time[3]}-{self.time[4]}"

        self.log = log
        self.episode_cnt = 0
        self.episodes = episodes

        if self.save_weights:
            self.create_save_directory()

        plt.ion()

    def create_save_directory(self):
        try:
            # makedirs also creates the parent "weights" directory if missing
            os.makedirs(self.path)
        except OSError:
            print("Creation of the directory %s failed" % self.path)
        else:
            print("Successfully created the directory %s " % self.path)

    def select_action(self, state):
        """Select an action from the input state."""
        # NoisyNet: no epsilon greedy action selection
        selected_action = self.dqn(torch.FloatTensor(state).to(
            self.device)).argmax()
        selected_action = selected_action.detach().cpu().numpy()

        self.transition = [state, selected_action]

        return selected_action

    def step(self, action):
        """Take an action and return the response of the env."""
        next_state, reward, done = self.env.step(action)

        self.transition += [reward, next_state, done]

        # N-step transition
        if self.use_n_step:
            one_step_transition = self.memory_n.store(*self.transition)
        # 1-step transition
        else:
            one_step_transition = self.transition

        # add a single step transition
        if one_step_transition:
            self.memory.store(*one_step_transition)

        return next_state, reward, done

    def update_model(self):
        """Update the model by gradient descent."""
        # PER needs beta to calculate weights
        samples = self.memory.sample_batch(self.beta)
        weights = torch.FloatTensor(samples["weights"].reshape(-1, 1)).to(
            self.device)
        indices = samples["indices"]

        # 1-step Learning loss
        elementwise_loss = self._compute_dqn_loss(samples, self.gamma)

        # PER: importance sampling before average
        loss = torch.mean(elementwise_loss * weights)

        # N-step Learning loss
        # Combine the 1-step and n-step losses to reduce variance;
        # the original Rainbow uses the n-step loss only.
        if self.use_n_step:
            gamma = self.gamma**self.n_step
            samples = self.memory_n.sample_batch_from_idxs(indices)
            elementwise_loss_n_loss = self._compute_dqn_loss(samples, gamma)
            elementwise_loss += elementwise_loss_n_loss

            # PER: importance sampling before average
            loss = torch.mean(elementwise_loss * weights)

        self.optimizer.zero_grad()
        loss.backward()
        # print(loss)
        clip_grad_norm_(self.dqn.parameters(), 10.0)
        self.optimizer.step()

        # PER: update priorities
        loss_for_prior = elementwise_loss.detach().cpu().numpy()
        new_priorities = loss_for_prior + self.prior_eps
        self.memory.update_priorities(indices, new_priorities)

        # NoisyNet: reset noise
        self.dqn.reset_noise()
        self.dqn_target.reset_noise()

        return loss.item()

    def train(self, num_frames, plotting_interval=100):
        """Train the agent."""

        if self.log:
            pass
            # config = {'gamma': self.gamma, 'log_interval': plotting_interval, 'learning_rate': self.lr,
            #           'directory': self.path, 'type': 'dqn', 'replay_memory': self.memory_size, 'environment': 'normal', 'seed': self.seed}
            # wandb.init(project='is_os', entity='pydqn', config=config, notes=self.env.reward_function, reinit=True, tags=['report'])
            # wandb.watch(self.dqn)

        self.env.reset()
        state = self.env.get_state()
        won = False
        update_cnt = 0
        losses = []
        scores = []
        score = 0
        frame_cnt = 0
        self.episode_cnt = 0

        for frame_idx in range(1, num_frames + 1):
            frame_cnt += 1
            action = self.select_action(state)
            next_state, reward, done = self.step(action)

            state = next_state
            score += reward

            # PER: anneal beta toward 1.0 (note that frame_cnt is reset at
            # the end of every episode, so the annealing restarts each episode)
            fraction = min(frame_cnt / num_frames, 1.0)
            self.beta = self.beta + fraction * (1.0 - self.beta)

            # truncate the episode after 500 frames
            if frame_cnt == 500:
                done = True

            # if episode ends
            if done:
                if reward > 0:
                    won = True
                self.env.reset()
                state = self.env.get_state()
                self.episode_cnt += 1
                scores.append(score)
                score = 0
                frame_cnt = 0

            # if training is ready
            if len(self.memory) >= self.batch_size:
                loss = self.update_model()
                losses.append(loss)
                update_cnt += 1

                # if hard update is needed
                if update_cnt % self.target_update == 0:
                    self._target_hard_update()

            # plotting
            if frame_idx % plotting_interval == 0:
                self._plot(frame_idx, scores, losses)

            if frame_idx % 1000 == 0:
                torch.save(self.dqn.state_dict(),
                           f'{self.path}/{frame_idx}.tar')
                print(f"model saved at:\n {self.path}/{frame_idx}.tar")

        # wandb.run.summary['won'] = won
        self.env.close()

    def _compute_dqn_loss(self, samples, gamma):
        """Return categorical dqn loss."""
        device = self.device  # for shortening the following lines
        state = torch.FloatTensor(samples["obs"]).to(device)
        next_state = torch.FloatTensor(samples["next_obs"]).to(device)
        action = torch.LongTensor(samples["acts"]).to(device)
        reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
        done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)

        # Categorical DQN algorithm
        delta_z = float(self.v_max - self.v_min) / (self.atom_size - 1)

        with torch.no_grad():
            # Double DQN
            next_action = self.dqn(next_state).argmax(1)
            next_dist = self.dqn_target.dist(next_state)
            next_dist = next_dist[range(self.batch_size), next_action]

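            # C51 projection: map the shifted support T_z = r + (1 - done) * gamma * z
            # back onto the fixed atoms; each probability mass in next_dist is
            # split between the two nearest atoms l and u in proportion to its
            # distance from them.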
            t_z = reward + (1 - done) * gamma * self.support
            t_z = t_z.clamp(min=self.v_min, max=self.v_max)
            b = (t_z - self.v_min) / delta_z
            l = b.floor().long()
            u = b.ceil().long()

            offset = (torch.linspace(
                0, (self.batch_size - 1) * self.atom_size,
                self.batch_size).long().unsqueeze(1).expand(
                    self.batch_size, self.atom_size).to(self.device))

            proj_dist = torch.zeros(next_dist.size(), device=self.device)
            proj_dist.view(-1).index_add_(0, (l + offset).view(-1),
                                          (next_dist *
                                           (u.float() - b)).view(-1))
            proj_dist.view(-1).index_add_(0, (u + offset).view(-1),
                                          (next_dist *
                                           (b - l.float())).view(-1))

        dist = self.dqn.dist(state)
        log_p = torch.log(dist[range(self.batch_size), action])
        elementwise_loss = -(proj_dist * log_p).sum(1)

        return elementwise_loss

    def _target_hard_update(self):
        """Hard update: target <- local."""
        self.dqn_target.load_state_dict(self.dqn.state_dict())

    def _plot(self, frame_cnt, scores, losses):
        self.ax1.cla()
        self.ax1.set_title(
            f'frames: {frame_cnt} score: {np.mean(scores[-10:])}')
        self.ax1.plot(scores[-999:], color='red')
        self.ax2.cla()
        self.ax2.set_title(f'loss: {np.mean(losses[-10:])}')
        self.ax2.plot(losses[-999:], color='blue')
        plt.show()
        plt.pause(0.1)

        # needed for wandb to not log nans
        # if frame_cnt < self.start_train + 11:
        #     loss = 0
        # else:
        #     loss = np.mean(losses[-10:])

        if self.log:
            pass
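A minimal usage sketch for this class. The environment wrapper here (MyCustomEnv) is a hypothetical name: DQNAgent expects an object exposing observation_dim, action_dim, reset(), get_state(), step(action) returning (next_state, reward, done), and close(). All hyperparameter values are illustrative.

env = MyCustomEnv()  # hypothetical environment wrapper
agent = DQNAgent(env,
                 memory_size=10000,
                 batch_size=32,
                 target_update=100,
                 n_step=3,
                 save_weights=False,  # skip creating the weights/ directory
                 log=False)
agent.train(num_frames=20000, plotting_interval=100)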
Example #2
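This example assumes the imports below. QNetwork, QNetwork_1dim, and PrioritizedReplayBuffer are project-local classes, so their import paths are left as commented placeholders (assumed names).

import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import wandb
from IPython.display import clear_output

# project-local components (module names are assumptions):
# from model import QNetwork, QNetwork_1dim
# from replay_buffer import PrioritizedReplayBuffer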
class Agent:
    def __init__(
            self,
            env: 'Environment',
            input_frame: ('int: the number of channels of input image'),
            input_dim: (
                'int: the width and height of pre-processed input image'),
            input_type: ('str: the type of input dimension'),
            num_frames: ('int: Total number of frames'),
            skipped_frame: ('int: The number of skipped frames'),
            eps_decay: ('float: Epsilon Decay_rate'),
            gamma: ('float: Discount Factor'),
            target_update_freq: ('int: Target Update Frequency (by frames)'),
            update_type: (
                'str: Update type for target network. Hard or Soft') = 'hard',
            soft_update_tau: ('float: Soft update ratio') = None,
            batch_size: ('int: Update batch size') = 32,
            buffer_size: ('int: Replay buffer size') = 1000000,
            alpha: (
                'float: Hyperparameter for how much prioritization is applied'
            ) = 0.5,
            beta: (
                'float: Hyperparameter for annealing the importance sampling correction'
            ) = 0.5,
            epsilon_for_priority: (
                'float: Hyperparameter for adding a small increment to the priority'
            ) = 1e-6,
            update_start_buffer_size: (
                'int: Update starting buffer size') = 50000,
            learning_rate: ('float: Learning rate') = 0.0004,
            eps_min: ('float: Epsilon Min') = 0.1,
            eps_max: ('float: Epsilon Max') = 1.0,
            device_num: ('int: GPU device number') = 0,
            rand_seed: ('int: Random seed') = None,
            plot_option: ('str: Plotting option') = False,
            model_path: ('str: Model saving path') = './'):

        self.action_dim = env.action_space.n
        self.device = torch.device(
            f'cuda:{device_num}' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path

        self.env = env
        self.input_frames = input_frame
        self.input_dim = input_dim
        self.num_frames = num_frames
        self.skipped_frame = skipped_frame
        self.epsilon = eps_max
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.gamma = gamma
        self.target_update_freq = target_update_freq
        self.update_cnt = 0
        self.update_type = update_type
        self.tau = soft_update_tau
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.update_start = update_start_buffer_size
        self.seed = rand_seed
        self.plot_option = plot_option

        # hyper parameters for PER
        self.alpha = alpha
        self.beta = beta
        self.beta_step = (1.0 - beta) / num_frames
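        # note: beta_step is only used by the commented-out linear annealing
        # in train(); the active code uses a fraction-based update instead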
        self.epsilon_for_priority = epsilon_for_priority

        if input_type == '1-dim':
            self.q_current = QNetwork_1dim(self.input_dim,
                                           self.action_dim).to(self.device)
            self.q_target = QNetwork_1dim(self.input_dim,
                                          self.action_dim).to(self.device)
        else:
            self.q_current = QNetwork(
                (self.input_frames, self.input_dim, self.input_dim),
                self.action_dim).to(self.device)
            self.q_target = QNetwork(
                (self.input_frames, self.input_dim, self.input_dim),
                self.action_dim).to(self.device)
        self.q_target.load_state_dict(self.q_current.state_dict())
        self.q_target.eval()
        self.optimizer = optim.Adam(self.q_current.parameters(),
                                    lr=learning_rate)

        if input_type == '1-dim':
            self.memory = PrioritizedReplayBuffer(self.buffer_size,
                                                  self.input_dim,
                                                  self.batch_size, self.alpha,
                                                  input_type)
        else:
            self.memory = PrioritizedReplayBuffer(
                self.buffer_size,
                (self.input_frames, self.input_dim, self.input_dim),
                self.batch_size, self.alpha, input_type)

    def select_action(
        self, state:
        'Must be pre-processed in the same way while updating current Q network. See def _compute_loss'
    ):

        if np.random.random() < self.epsilon:
            return np.zeros(self.action_dim), self.env.action_space.sample()
        else:
            state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
            Qs = self.q_current(state)
            action = Qs.argmax()
            return Qs.detach().cpu().numpy(), action.detach().item()

    def processing_resize_and_gray(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)  # Pure
        # frame = cv2.cvtColor(frame[:177, 32:128, :], cv2.COLOR_RGB2GRAY) # Boxing
        # frame = cv2.cvtColor(frame[2:198, 7:-7, :], cv2.COLOR_RGB2GRAY) # Breakout
        frame = cv2.resize(frame,
                           dsize=(self.input_dim, self.input_dim)).reshape(
                               self.input_dim, self.input_dim).astype(np.uint8)
        return frame

    def get_state(self, state, action, skipped_frame=0):
        '''
        num_frames: how many frames to merge
        input_size: height and width of the resized input image
        skipped_frame: how many frames to skip
        '''
        next_state = np.zeros(
            (self.input_frames, self.input_dim, self.input_dim))
        for i in range(len(state) - 1):
            next_state[i] = state[i + 1]

        rewards = 0
        dones = 0
        for j in range(skipped_frame):
            state, reward, done, _ = self.env.step(action)
            rewards += reward
            dones += int(done)
        state, reward, done, _ = self.env.step(action)
        next_state[-1] = self.processing_resize_and_gray(state)
        rewards += reward
        dones += int(done)
        return rewards, next_state, dones

    def get_state_1dim(self, state, action, skipped_frame=0):
        '''
        num_frames: how many frames to merge
        input_size: height and width of the resized input image
        skipped_frame: how many frames to skip
        '''
        next_state = np.zeros((self.input_frames, self.input_dim))
        for i in range(len(state) - 1):
            next_state[i] = state[i + 1]

        rewards = 0
        dones = 0
        for _ in range(skipped_frame):
            state, reward, done, _ = self.env.step(action)
            rewards += reward
            dones += int(done)
        state, reward, done, _ = self.env.step(action)
        next_state[-1] = state
        rewards += reward
        dones += int(done)
        return rewards, next_state, dones

    def get_init_state(self):

        init_state = np.zeros(
            (self.input_frames, self.input_dim, self.input_dim))
        init_frame = self.env.reset()
        init_state[0] = self.processing_resize_and_gray(init_frame)

        for i in range(1, self.input_frames):
            action = self.env.action_space.sample()
            for j in range(self.skipped_frame):
                state, _, _, _ = self.env.step(action)
            state, _, _, _ = self.env.step(action)
            init_state[i] = self.processing_resize_and_gray(state)
        return init_state

    def get_init_state_1dim(self):

        init_state = np.zeros((self.input_frames, self.input_dim))
        init_frame = self.env.reset()
        init_state[0] = init_frame

        for i in range(1, self.input_frames):
            action = self.env.action_space.sample()
            for j in range(self.skipped_frame):
                state, _, _, _ = self.env.step(action)
            state, _, _, _ = self.env.step(action)
            init_state[i] = state
        return init_state

    def store(self, state, action, reward, next_state, done):
        self.memory.store(state, action, reward, next_state, done)

    def update_current_q_net(self):
        '''The method that differs between the Dueling and PER versions of the Agent class.'''
        batch = self.memory.batch_load(self.beta)
        weights = torch.FloatTensor(batch['weights'].reshape(-1, 1)).to(
            self.device)
        sample_wise_loss = self._compute_loss(
            batch)  # PER: shape of loss -> (batch, 1)
        loss = torch.mean(sample_wise_loss * weights)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # For PER: update priorities of the samples.
        sample_wise_loss = sample_wise_loss.detach().cpu().numpy()
        batch_priorities = sample_wise_loss + self.epsilon_for_priority
        self.memory.update_priorities(batch['indices'], batch_priorities)

        return loss.item()

    def target_soft_update(self):
        for target_param, current_param in zip(self.q_target.parameters(),
                                               self.q_current.parameters()):
            target_param.data.copy_(self.tau * current_param.data +
                                    (1.0 - self.tau) * target_param.data)

    def target_hard_update(self):
        self.update_cnt = (self.update_cnt + 1) % self.target_update_freq
        if self.update_cnt == 0:
            self.q_target.load_state_dict(self.q_current.state_dict())

    def train(self):
        tic = time.time()
        losses = []
        scores = []
        epsilons = []
        avg_scores = [-1000.0]  # flat list so max(avg_scores) compares floats

        score = 0

        print("Storing initial buffer..")
        # state = self.get_init_state()
        # state = self.get_init_state_1dim()
        state = self.env.reset()
        for frame_idx in range(1, self.update_start + 1):
            _, action = self.select_action(state)
            next_state, reward, done, _ = self.env.step(action)
            self.store(state, action, reward, next_state, done)
            state = next_state
            if done: state = self.env.reset()

        print("Done. Start learning..")
        history_store = []
        for frame_idx in range(1, self.num_frames + 1):
            Qs, action = self.select_action(state)
            next_state, reward, done, _ = self.env.step(action)
            self.store(state, action, reward, next_state, done)
            history_store.append([state, Qs, action, reward, next_state, done])
            loss = self.update_current_q_net()

            if self.update_type == 'hard': self.target_hard_update()
            elif self.update_type == 'soft': self.target_soft_update()

            score += reward
            losses.append(loss)

            if done:
                scores.append(score)
                if np.mean(scores[-10:]) > max(avg_scores):
                    torch.save(
                        self.q_current.state_dict(),
                        self.model_path + '{}_Score:{}.pt'.format(
                            frame_idx, np.mean(scores[-10:])))
                    training_time = round((time.time() - tic) / 3600, 1)
                    np.save(
                        self.model_path +
                        '{}_history_Score_{}_{}hrs.npy'.format(
                            frame_idx, score, training_time),
                        np.array(history_store))
                    print(
                        "          | Model saved. Recent scores: {}, Training time: {}hrs"
                        .format(scores[-10:], training_time),
                        ' /'.join(os.getcwd().split('/')[-3:]))
                avg_scores.append(np.mean(scores[-10:]))

                if self.plot_option == 'inline':
                    # score was already appended above; only track epsilon here
                    epsilons.append(self.epsilon)
                    self._plot(frame_idx, scores, losses, epsilons)
                elif self.plot_option == 'wandb':
                    Q_mean = np.mean(np.array(history_store)[:, 1])
                    wandb.log({
                        'Score': score,
                        'loss(10 frames avg)': np.mean(losses[-10:]),
                        'Q (mean)': Q_mean,
                        'Epsilon': self.epsilon,
                        'beta': self.beta
                    })
                    print(score, end='\r')
                else:
                    print(score, end='\r')

                score = 0
                state = self.env.reset()
                history_store = []
            else:
                state = next_state

            self._epsilon_step()

            # self.beta = min(self.beta+self.beta_step, 1.0) # for PER. beta is increased linearly up to 1.0
            fraction = min(frame_idx / self.num_frames, 1.0)
            self.beta = self.beta + fraction * (1.0 - self.beta)

        print("Total training time: {}(hrs)".format(
            (time.time() - tic) / 3600))

    def _epsilon_step(self):
        '''Epsilon decay control.

        Note: the eps_decay passed to __init__ is not used here; the decay
        rates are hardcoded below.
        '''
        eps_decay_init = 1 / 1200000
        eps_decay = [
            eps_decay_init, eps_decay_init / 2.5, eps_decay_init / 3.5,
            eps_decay_init / 5.5
        ]

        if self.epsilon > 0.30:
            self.epsilon = max(self.epsilon - eps_decay[0], 0.1)
        elif self.epsilon > 0.27:
            self.epsilon = max(self.epsilon - eps_decay[1], 0.1)
        # the original threshold here was 1.7, which is unreachable after the
        # checks above; 0.17 is assumed to be the intended value
        elif self.epsilon > 0.17:
            self.epsilon = max(self.epsilon - eps_decay[2], 0.1)
        else:
            self.epsilon = max(self.epsilon - eps_decay[3], 0.1)

    def _compute_loss(self, batch: "Dictionary (S, A, R', S', Dones)"):
        # If normalization is used, it must be applied to 'state' and 'next_state' here. ex) state/255
        states = torch.FloatTensor(batch['states']).to(self.device)
        next_states = torch.FloatTensor(batch['next_states']).to(self.device)
        actions = torch.LongTensor(batch['actions'].reshape(-1,
                                                            1)).to(self.device)
        rewards = torch.FloatTensor(batch['rewards'].reshape(-1, 1)).to(
            self.device)
        dones = torch.FloatTensor(batch['dones'].reshape(-1,
                                                         1)).to(self.device)

        current_q = self.q_current(states).gather(1, actions)
        # Double DQN: the next line is the only difference from vanilla DQN.
        # The online network selects the action, the target network evaluates it.
        next_q = self.q_target(next_states).gather(
            1,
            self.q_current(next_states).argmax(dim=1, keepdim=True)).detach()
        mask = 1 - dones
        target = (rewards + (mask * self.gamma * next_q)).to(self.device)

        # For PER, the shape of loss is (batch, 1). Therefore, using "reduction='none'" option.
        sample_wise_loss = F.smooth_l1_loss(current_q,
                                            target,
                                            reduction="none")
        return sample_wise_loss

    def _plot(self, frame_idx, scores, losses, epsilons):
        clear_output(True)
        plt.figure(figsize=(20, 5), facecolor='w')
        plt.subplot(131)
        plt.title('frame %s. score: %s' % (frame_idx, np.mean(scores[-10:])))
        plt.plot(scores)
        plt.subplot(132)
        plt.title('loss')
        plt.plot(losses)
        plt.subplot(133)
        plt.title('epsilons')
        plt.plot(epsilons)
        plt.show()
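A minimal usage sketch, assuming a classic Gym Atari environment with the 4-tuple step API (obs, reward, done, info) and the project-local QNetwork / PrioritizedReplayBuffer implementations. The environment id and every hyperparameter value below are illustrative, not taken from the original.

import gym

env = gym.make('BreakoutDeterministic-v4')  # illustrative environment id
agent = Agent(env,
              input_frame=4,            # stack of 4 grayscale frames
              input_dim=84,             # 84x84 pre-processed frames
              input_type='2-dim',       # any value other than '1-dim' selects the CNN QNetwork
              num_frames=1000000,
              skipped_frame=0,
              eps_decay=1 / 1200000,    # stored but unused; _epsilon_step hardcodes its own rates
              gamma=0.99,
              target_update_freq=8000,
              update_start_buffer_size=50000,
              plot_option=False,
              model_path='./')
agent.train()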