Example #1
import json
import os
from datetime import datetime

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# ActorCritic (the policy/value network) and preprocess (frame preprocessing)
# are assumed to be defined elsewhere in the project this snippet comes from.


class A3C():
    '''Implementation of N-step Asynchronous Advantage Actor Critic.'''
    def __init__(self, args, env, train=True):
        self.args = args
        self.set_random_seeds()
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Create the environment.
        self.env = gym.make(env)
        self.environment_name = env

        # Setup policy model.
        self.policy = ActorCritic(4, self.env.action_space.n)
        self.policy.apply(self.initialize_weights)

        # Setup critic model.
        self.critic = ActorCritic(4, self.env.action_space.n)
        self.critic.apply(self.initialize_weights)

        # Setup optimizer.
        self.eps = 1e-10  # To avoid divide-by-zero error.
        self.policy_optimizer = optim.Adam(self.policy.parameters(),
                                           lr=args.policy_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(),
                                           lr=args.critic_lr)

        # Model weights path.
        self.timestamp = datetime.now().strftime(
            'a2c-breakout-%Y-%m-%d_%H-%M-%S')
        self.weights_path = 'models/%s/%s' % (self.environment_name,
                                              self.timestamp)

        # Load pretrained weights.
        if args.weights_path: self.load_model()
        self.policy.to(self.device)
        self.critic.to(self.device)

        # Video render mode.
        if args.render:
            self.policy.eval()
            self.generate_episode(render=True)
            self.plot()
            return

        # Data for plotting.
        self.rewards_data = []  # n * [epoch, mean(returns), std(returns)]

        # Network training mode.
        if train:
            # Tensorboard logging.
            self.logdir = 'logs/%s/%s' % (self.environment_name,
                                          self.timestamp)
            self.summary_writer = SummaryWriter(self.logdir)

            # Save hyperparameters.
            with open(self.logdir + '/training_parameters.json', 'w') as f:
                json.dump(vars(self.args), f, indent=4)

    def initialize_weights(self, layer):
        if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def set_random_seeds(self):
        torch.manual_seed(self.args.random_seed)
        np.random.seed(self.args.random_seed)
        # cudnn benchmarking speeds up fixed-size convolutions but is not fully deterministic.
        torch.backends.cudnn.benchmark = True

    def save_model(self, epoch):
        '''Helper function to save model state and weights.'''
        if not os.path.exists(self.weights_path):
            os.makedirs(self.weights_path)
        torch.save(
            {
                'policy_state_dict': self.policy.state_dict(),
                'policy_optimizer': self.policy_optimizer.state_dict(),
                'critic_state_dict': self.critic.state_dict(),
                'critic_optimizer': self.critic_optimizer.state_dict(),
                'rewards_data': self.rewards_data,
                'epoch': epoch
            }, os.path.join(self.weights_path, 'model_%d.h5' % epoch))

    def load_model(self):
        '''Helper function to load model state and weights. '''
        if os.path.isfile(self.args.weights_path):
            print('=> Loading checkpoint', self.args.weights_path)
            self.checkpoint = torch.load(self.args.weights_path)
            self.policy.load_state_dict(self.checkpoint['policy_state_dict'])
            self.policy_optimizer.load_state_dict(
                self.checkpoint['policy_optimizer'])
            self.critic.load_state_dict(self.checkpoint['critic_state_dict'])
            self.critic_optimizer.load_state_dict(
                self.checkpoint['critic_optimizer'])
            self.rewards_data = self.checkpoint['rewards_data']
        else:
            raise Exception('No checkpoint found at %s' %
                            self.args.weights_path)

    def train(self):
        '''Trains the policy and critic with N-step advantage actor-critic updates,
        one episode per epoch.'''
        for epoch in range(self.args.num_episodes):
            # Generate episode data.
            returns, log_probs, value_function, train_rewards = \
                self.generate_episode()
            self.summary_writer.add_scalar('train/cumulative_rewards',
                                           train_rewards, epoch)
            self.summary_writer.add_scalar('train/trajectory_length',
                                           returns.size()[0], epoch)

            # Compute loss and policy gradient.
            self.policy_optimizer.zero_grad()
            policy_loss = ((returns - value_function.detach()) *
                           -log_probs).mean()
            policy_loss.backward()
            self.policy_optimizer.step()

            self.critic_optimizer.zero_grad()
            critic_loss = F.mse_loss(returns, value_function)
            critic_loss.backward()
            self.critic_optimizer.step()

            # Test the model.
            if epoch % self.args.test_interval == 0:
                self.policy.eval()
                print('\nTesting')
                rewards = [
                    self.generate_episode(test=True)
                    for _ in range(self.args.test_episodes)
                ]
                rewards_mean, rewards_std = np.mean(rewards), np.std(rewards)
                print(
                    'Test Rewards (Mean): %.3f | Test Rewards (Std): %.3f\n' %
                    (rewards_mean, rewards_std))
                self.rewards_data.append([epoch, rewards_mean, rewards_std])
                self.summary_writer.add_scalar('test/rewards_mean',
                                               rewards_mean, epoch)
                self.summary_writer.add_scalar('test/rewards_std', rewards_std,
                                               epoch)
                self.policy.train()

            # Logging.
            if epoch % self.args.log_interval == 0:
                print(
                    'Epoch: {0:05d}/{1:05d} | Policy Loss: {2:.3f} | Value Loss: {3:.3f}'
                    .format(epoch, self.args.num_episodes, policy_loss,
                            critic_loss))
                self.summary_writer.add_scalar('train/policy_loss',
                                               policy_loss, epoch)
                self.summary_writer.add_scalar('train/critic_loss',
                                               critic_loss, epoch)

            # Save the model.
            if epoch % self.args.save_interval == 0:
                self.save_model(epoch)

        self.save_model(epoch)
        self.summary_writer.close()

    def generate_episode(self,
                         gamma=0.99,
                         test=False,
                         render=False,
                         max_iters=10000):
        '''
        Generates an episode by executing the current policy in the given env.
        Returns (in training mode):
        - N-step discounted returns, indexed by time step
        - log-probabilities of the sampled actions, indexed by time step
        - critic value estimates, indexed by time step
        - the cumulative (undiscounted) episode reward
        In test mode only the cumulative reward is returned; in render mode
        nothing is returned.
        '''
        iters = 0
        done = False

        # Wrap the env in a Monitor before reset if rendering, so that the
        # recorded video captures the steps taken below.
        if render:
            save_path = 'videos/%s/epoch-%s' % (self.environment_name,
                                                self.checkpoint['epoch'])
            if not os.path.exists(save_path): os.makedirs(save_path)
            self.env = gym.wrappers.Monitor(self.env, save_path, force=True)
        state = self.env.reset()

        batches = []
        states = [torch.zeros(84, 84, device=self.device).float()] * 3
        rewards, returns = [], []
        actions, log_probs = [], []

        while not done:
            # Run policy on current state to log probabilities of actions.
            states.append(
                torch.tensor(preprocess(state),
                             device=self.device).float().squeeze(0))
            batches.append(torch.stack(states[-4:]))
            action_probs = self.policy.forward(
                batches[-1].unsqueeze(0)).squeeze(0)

            # Sample an action from the action distribution (greedy if
            # deterministic evaluation is enabled).
            if test and self.args.det_eval:
                action = torch.argmax(action_probs)
            else:
                action = torch.distributions.Categorical(
                    logits=action_probs).sample()
            actions.append(action)
            log_probs.append(action_probs[action])

            # Run simulation with current action to get new state and reward.
            if render: self.env.render()
            state, reward, done, _ = self.env.step(action.cpu().numpy())
            rewards.append(reward)

            # Break if the episode takes too long.
            iters += 1
            if iters > max_iters: break

        # Save video and close rendering.
        cum_rewards = np.sum(rewards)
        if render:
            self.env.close()
            print('\nCumulative Rewards:', cum_rewards)
            return

        # Return cumulative rewards for test mode.
        if test: return cum_rewards

        # Normalize rewards.
        rewards = np.array(rewards) / self.args.reward_normalizer

        # Compute value.
        values = []
        minibatches = torch.split(torch.stack(batches), 256)
        for minibatch in minibatches:
            values.append(
                self.critic.forward(minibatch, action=False).squeeze(1))
        values = torch.cat(values)
        discounted_values = values * gamma**self.args.n

        # Compute the N-step bootstrapped returns, iterating backwards in time:
        # R_t = sum_{k=0..n-1} gamma^k * r_{t+k} + gamma^n * V(s_{t+n}).
        n_step_rewards = np.zeros((1, self.args.n))
        for i in reversed(range(rewards.shape[0])):
            if i + self.args.n >= rewards.shape[0]:
                V_end = 0
            else:
                V_end = discounted_values[i + self.args.n]
            n_step_rewards[0, :-1] = n_step_rewards[0, 1:] * gamma
            n_step_rewards[0, -1] = rewards[i]

            # Keep dtype consistent with the float32 critic outputs.
            n_step_return = torch.tensor(
                n_step_rewards.sum(), dtype=torch.float32,
                device=self.device).unsqueeze(0) + V_end
            returns.append(n_step_return)

        # Normalize returns.
        # returns = torch.stack(returns)
        # mean_return, std_return = returns.mean(), returns.std()
        # returns = (returns - mean_return) / (std_return + self.eps)

        return torch.stack(returns[::-1]).detach().squeeze(1), torch.stack(
            log_probs), values.squeeze(), cum_rewards

    def plot(self):
        # Save the plot.
        filename = os.path.join(
            'plots',
            *self.args.weights_path.split('/')[-2:]).replace('.h5', '.png')
        if not os.path.exists(os.path.dirname(filename)):
            os.makedirs(os.path.dirname(filename))

        # Make error plot with mean, std of rewards.
        data = np.asarray(self.rewards_data)
        plt.errorbar(data[:, 0],
                     data[:, 1],
                     data[:, 2],
                     lw=2.5,
                     elinewidth=1.5,
                     ecolor='grey',
                     barsabove=True,
                     capthick=2,
                     capsize=3)
        plt.title('Cumulative Rewards (Mean/Std) Plot for A3C Algorithm')
        plt.xlabel('Number of Episodes')
        plt.ylabel('Cumulative Rewards')
        plt.grid()
        plt.savefig(filename, dpi=300)
        plt.show()
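
A minimal launch sketch for the class above, assuming ActorCritic and preprocess are importable from the surrounding project. The field names on args are inferred from the attributes A3C reads (policy_lr, critic_lr, n, reward_normalizer, and so on), and the environment id is only a plausible choice matching the 'a2c-breakout' checkpoint prefix; treat every value as a placeholder rather than a tuned hyperparameter.

from types import SimpleNamespace

# Hypothetical argument bundle; each field mirrors an attribute read off `args` above.
args = SimpleNamespace(
    random_seed=42,
    policy_lr=1e-4,
    critic_lr=1e-4,
    n=20,                     # N of the N-step return.
    reward_normalizer=100.0,
    num_episodes=5000,
    test_interval=200,
    test_episodes=20,
    log_interval=50,
    save_interval=500,
    det_eval=False,
    weights_path='',          # Empty string: train from scratch.
    render=False)

agent = A3C(args, 'Breakout-v0')  # Any Atari id that `preprocess` can handle.
agent.train()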
Example #2
import numpy as np
import torch
import torch.optim as optim

# ActorCritic is assumed to be a network returning (value, action_probabilities)
# for a batch of flat observations, defined elsewhere in the project.


class Agent():
    def __init__(self,
                 env,
                 use_cnn=False,
                 learning_rate=3e-4,
                 gamma=0.99,
                 buffer_size=10000):
        self.env = env
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.num_in = self.env.observation_space.shape[0]
        self.num_out = self.env.action_space.n

        self.ac_net = ActorCritic(self.env.observation_space.shape[0],
                                  self.env.action_space.n)
        self.ac_optimizer = optim.Adam(self.ac_net.parameters(),
                                       lr=learning_rate)

    def update(self, rewards, values, next_value, log_probs, entropy):
        # Discounted returns, bootstrapped from the value of the last state.
        qvals = np.zeros(len(values))
        qval = float(next_value)  # Detach the bootstrap value from the graph.
        for t in reversed(range(len(rewards))):
            qval = rewards[t] + self.gamma * qval
            qvals[t] = qval

        values = torch.FloatTensor(np.asarray(values)).view(-1)
        qvals = torch.FloatTensor(qvals)
        log_probs = torch.stack(log_probs)

        advantage = qvals - values
        actor_loss = (-log_probs * advantage).mean()
        critic_loss = advantage.pow(2).mean()
        ac_loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy

        self.ac_optimizer.zero_grad()
        ac_loss.backward()
        self.ac_optimizer.step()

    def get_ac_output(self, state):
        state = torch.from_numpy(state).float().unsqueeze(0)
        value, policy_dist = self.ac_net.forward(state)
        action = np.random.choice(self.num_out,
                                  p=policy_dist.detach().numpy().squeeze(0))

        return action, policy_dist, value

    def train(self, max_episode, max_step):
        for episode in range(max_episode):
            rewards = []
            values = []
            log_probs = []
            entropy_term = 0
            episode_reward = 0

            state = self.env.reset()
            for steps in range(max_step):
                action, policy_dist, value = self.get_ac_output(state)
                new_state, reward, done, _ = self.env.step(action)

                log_prob = torch.log(policy_dist.squeeze(0)[action])
                entropy = -torch.sum(policy_dist * torch.log(policy_dist))

                rewards.append(reward)
                values.append(value.detach().numpy()[0])
                log_probs.append(log_prob)
                entropy_term += entropy
                state = new_state
                episode_reward += reward

                if done:
                    if episode % 10 == 0:
                        print("episode: " + str(episode) + ": " +
                              str(episode_reward))
                    break

            # Bootstrap from the value of the final state (zero if terminal).
            if done:
                next_value = 0.0
            else:
                _, _, next_value = self.get_ac_output(state)
            self.update(rewards, values, next_value, log_probs, entropy_term)
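
A short usage sketch for the second example, again assuming ActorCritic comes from the surrounding project and returns (value, action_probabilities) for a flat observation batch. CartPole is used purely as an illustrative discrete-action environment.

import gym

env = gym.make('CartPole-v1')  # Any env with a flat observation vector and discrete actions.
agent = Agent(env, learning_rate=3e-4, gamma=0.99)
agent.train(max_episode=1000, max_step=500)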