Example #1
class DDPG:
    def __init__(self, env_fn, save_dir, ac_kwargs=dict(), seed=0, tensorboard_logdir = None,
         replay_size=int(1e6), gamma=0.99, 
         tau=0.995, pi_lr=1e-3, q_lr=1e-3, batch_size=100, start_steps=10000, 
         update_after=1000, update_every=50, act_noise=0.1, num_test_episodes=10, 
         max_ep_len=1000, logger_kwargs=dict(), save_freq=1, ngpu=1):    
        '''
        Deep Deterministic Policy Gradients (DDPG)
        Args:
            env_fn: function to create the gym environment
            save_dir: path to save directory
            actor_critic: Class for the actor-critic pytorch module
            ac_kwargs (dict): any keyword argument for the actor_critic
                        (1) hidden_sizes=(256, 256)
                        (2) activation=nn.ReLU
                        (3) device='cpu'
            seed (int): seed for random generators
            replay_size (int): Maximum length of replay buffer.
            gamma (float): Discount factor. (Always between 0 and 1.)
            tau (float): Interpolation factor in polyak averaging for target 
                networks.
            pi_lr (float): Learning rate for policy.
            q_lr (float): Learning rate for Q-networks.
            batch_size (int): Minibatch size for SGD.
            start_steps (int): Number of steps for uniform-random action selection,
                before running real policy. Helps exploration.
            update_after (int): Number of env interactions to collect before
                starting to do gradient descent updates. Ensures replay buffer
                is full enough for useful updates.
            update_every (int): Number of env interactions that should elapse
                between gradient descent updates. Note: Regardless of how long 
                you wait between updates, the ratio of env steps to gradient steps 
                is locked to 1.
            act_noise (float): Stddev for Gaussian exploration noise added to 
                policy at training time. (At test time, no noise is added.)
            num_test_episodes (int): Number of episodes to test the deterministic
                policy at the end of each epoch.
            max_ep_len (int): Maximum length of trajectory / episode / rollout.
            logger_kwargs (dict): Keyword args for Logger. 
                        (1) output_dir = None
                        (2) output_fname = 'progress.pickle'
            save_freq (int): How often (in terms of timesteps) to save
                the current policy and value function.
        '''
        # logger stuff
        self.logger = Logger(**logger_kwargs)

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.env = env_fn()

        # Action Limit for clamping
        self.act_limit = self.env.action_space.high[0]

        # Create actor-critic module
        self.ngpu = ngpu
        self.actor_critic = get_actor_critic_module(ac_kwargs, 'ddpg')
        self.ac_kwargs = ac_kwargs
        self.ac = self.actor_critic(self.env.observation_space, self.env.action_space, device=self.device, ngpu=self.ngpu, **ac_kwargs)
        self.ac_targ = deepcopy(self.ac)

        # Freeze target networks with respect to optimizers
        for p in self.ac_targ.parameters():
            p.requires_grad = False
        
        # Experience buffer
        self.replay_size = replay_size
        self.replay_buffer = ReplayBuffer(int(replay_size))

        # Set up optimizers for actor and critic
        self.pi_lr = pi_lr
        self.q_lr = q_lr
        self.pi_optimizer = Adam(self.ac.pi.parameters(), lr=pi_lr)
        self.q_optimizer = Adam(self.ac.q.parameters(), lr=q_lr)

        self.gamma = gamma
        self.tau = tau
        self.act_noise = act_noise
        # self.obs_dim = self.env.observation_space.shape[0]
        self.act_dim = self.env.action_space.shape[0]
        self.num_test_episodes = num_test_episodes
        self.max_ep_len = self.env.spec.max_episode_steps if self.env.spec.max_episode_steps is not None else max_ep_len
        self.start_steps = start_steps
        self.update_after = update_after
        self.update_every = update_every
        self.batch_size = batch_size
        self.save_freq = save_freq

        self.best_mean_reward = -np.inf
        self.save_dir = save_dir
        self.tensorboard_logdir = tensorboard_logdir

    def reinit_network(self):
        '''
        Re-initialize network weights and optimizers for a fresh agent to train
        '''        
        
        # Create actor-critic module
        self.best_mean_reward = -np.inf
        self.ac = self.actor_critic(self.env.observation_space, self.env.action_space, device=self.device, ngpu=self.ngpu, **self.ac_kwargs)
        self.ac_targ = deepcopy(self.ac)

        # Freeze target networks with respect to optimizers
        for p in self.ac_targ.parameters():
            p.requires_grad = False
        
        # Experience buffer
        self.replay_buffer = ReplayBuffer(int(self.replay_size))

        # Set up optimizers for actor and critic
        self.pi_optimizer = Adam(self.ac.pi.parameters(), lr=self.pi_lr)
        self.q_optimizer = Adam(self.ac.q.parameters(), lr=self.q_lr)

    def update(self, experiences, timestep):
        '''
        Do gradient updates for actor-critic models
        Args:
            experiences: sampled s, a, r, s', terminals from the replay buffer
            timestep (int): current training timestep, used for tensorboard logging
        '''
        # Get states, action, rewards, next_states, terminals from experiences
        self.ac.train()
        self.ac_targ.train()
        states, actions, rewards, next_states, terminals = experiences
        states = states.to(self.device)
        next_states = next_states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        terminals = terminals.to(self.device)

        # --------------------- Optimizing critic ---------------------
        self.q_optimizer.zero_grad()
        # calculating q loss
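        # Critic target (Bellman backup): Q'(s,a) = r + gamma * (1 - d) * Q_targ(s', pi_targ(s')),
        # computed under no_grad so no gradients flow through the target networks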
        Q_values = self.ac.q(states, actions)
        with torch.no_grad():
            next_actions = self.ac_targ.pi(next_states)
            next_Q = self.ac_targ.q(next_states, next_actions) * (1-terminals)
            Qprime = rewards + (self.gamma * next_Q)
        
        # MSE loss
        loss_q = ((Q_values-Qprime)**2).mean()
        loss_info = dict(Qvals=Q_values.detach().cpu().numpy().tolist())

        loss_q.backward()
        self.q_optimizer.step()

        # --------------------- Optimizing actor ---------------------
        # Freeze Q-network so no computational resources are wasted computing gradients for it
        for p in self.ac.q.parameters():
            p.requires_grad = False

        self.pi_optimizer.zero_grad()
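        # Deterministic policy gradient: maximise Q(s, pi(s)), i.e. minimise -Q(s, pi(s))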
        loss_pi = -self.ac.q(states, self.ac.pi(states)).mean()
        loss_pi.backward()
        self.pi_optimizer.step()

        # Unfreeze Q-network for next update step
        for p in self.ac.q.parameters():
            p.requires_grad = True
            
        # Record Q loss, policy loss and Q values via the logger
        self.logger.store(LossQ=loss_q.item(), LossPi=loss_pi.item(), **loss_info)
        self.tensorboard_logger.add_scalar("loss/q_loss", loss_q.item(), timestep)
        self.tensorboard_logger.add_scalar("loss/pi_loss", loss_pi.item(), timestep)
        # Polyak-average the target networks: theta_targ <- tau * theta_targ + (1 - tau) * theta
        with torch.no_grad():
            for p, p_targ in zip(self.ac.parameters(), self.ac_targ.parameters()):
                p_targ.data.mul_(self.tau)
                p_targ.data.add_((1-self.tau)*p.data)

    def get_action(self, obs, noise_scale):
        '''
        Input the current observation into the actor network to calculate action to take.
        Args:
            obs (numpy ndarray): Current state of the environment. Only 1 state, not a batch of states
            noise_scale (float): Stddev for Gaussian exploration noise
        Return:
            Action (numpy ndarray): Scaled action that is clipped to environment's action limits
        '''
        self.ac.eval()
        self.ac_targ.eval()
        obs = torch.as_tensor([obs], dtype=torch.float32).to(self.device)
        action = self.ac.act(obs).squeeze()
        if len(action.shape) == 0:
            action = np.array([action])
        action += noise_scale*np.random.randn(self.act_dim)
        return np.clip(action, -self.act_limit, self.act_limit)

    def evaluate_agent(self):
        '''
        Run the current model through test environment for <self.num_test_episodes> episodes
        without noise exploration, and store the episode return and length into the logger.
        
        Used to measure how well the agent is doing.
        '''
        self.env.training=False
        for i in range(self.num_test_episodes):
            state, done, ep_ret, ep_len = self.env.reset(), False, 0, 0
            while not (done or (ep_len==self.max_ep_len)):
                # Take deterministic action with 0 noise added
                state, reward, done, _ = self.env.step(self.get_action(state, 0))
                ep_ret += reward
                ep_len += 1
            self.logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
        self.env.training=True

    def save_weights(self, best=False, fname=None):
        '''
        Save the PyTorch model weights of ac and ac_targ, and pickle the replay buffer
        and environment to preserve parameters such as normalisation statistics.
        Args:
            best (bool): if True, save as best.pth
            fname (string): if specified, save as <fname>
        '''
        if fname is not None:
            _fname = fname
        elif best:
            _fname = "best.pth"
        else:
            _fname = "model_weights.pth"
        
        print('saving checkpoint...')
        checkpoint = {
            'ac': self.ac.state_dict(),
            'ac_target': self.ac_targ.state_dict(),
            'pi_optimizer': self.pi_optimizer.state_dict(),
            'q_optimizer': self.q_optimizer.state_dict()
        }
        torch.save(checkpoint, os.path.join(self.save_dir, _fname))
        self.replay_buffer.save(os.path.join(self.save_dir, "replay_buffer.pickle"))
        self.env.save(os.path.join(self.save_dir, "env.json"))
        print(f"checkpoint saved at {os.path.join(self.save_dir, _fname)}")

    def load_weights(self, best=True, load_buffer=True):
        '''
        Load the model weights and replay buffer from self.save_dir
        Args:
            best (bool): If True, load from the weights file with the best mean episode reward
            load_buffer (bool): If True, load the replay buffer from the pickled file
        '''
        if best:
            fname = "best.pth"
        else:
            fname = "model_weights.pth"
        checkpoint_path = os.path.join(self.save_dir, fname)
        if os.path.isfile(checkpoint_path):
            if load_buffer:
                self.replay_buffer.load(os.path.join(self.save_dir, "replay_buffer.pickle"))
            key = 'cuda' if torch.cuda.is_available() else 'cpu'
            checkpoint = torch.load(checkpoint_path, map_location=key)
            self.ac.load_state_dict(sanitise_state_dict(checkpoint['ac'], self.ngpu>1))
            self.ac_targ.load_state_dict(sanitise_state_dict(checkpoint['ac_target'], self.ngpu>1))
            self.pi_optimizer.load_state_dict(sanitise_state_dict(checkpoint['pi_optimizer'], self.ngpu>1))
            self.q_optimizer.load_state_dict(sanitise_state_dict(checkpoint['q_optimizer'], self.ngpu>1))

            env_path = os.path.join(self.save_dir, "env.json")
            if os.path.isfile(env_path):
                self.env = self.env.load(env_path)
                print("Environment loaded")
            
            print('checkpoint loaded at {}'.format(checkpoint_path))
        else:
            raise OSError("Checkpoint file not found.")    

    def learn_one_trial(self, timesteps, trial_num):
        state, ep_ret, ep_len = self.env.reset(), 0, 0
        episode = 0
        for timestep in tqdm(range(timesteps)):
            # Until start_steps have elapsed, sample random actions from the action space
            # to encourage exploration; sample from the policy network after that
            if timestep<=self.start_steps:
                action = self.env.action_space.sample()
            else:
                action = self.get_action(state, self.act_noise)

            # step the environment
            next_state, reward, done, _ = self.env.step(action)
            ep_ret += reward
            ep_len += 1

            # ignore the 'done' signal if the episode merely timed out at max_ep_len
            done = False if ep_len==self.max_ep_len else done

            # store experience to replay buffer
            self.replay_buffer.append(state, action, reward, next_state, done)

            # Critical step to update current state
            state = next_state
            
            # Update handling
            if timestep>=self.update_after and (timestep+1)%self.update_every==0:
                for _ in range(self.update_every):
                    experiences = self.replay_buffer.sample(self.batch_size)
                    self.update(experiences, timestep)
            
            # End of trajectory/episode handling
            if done or (ep_len==self.max_ep_len):
                self.logger.store(EpRet=ep_ret, EpLen=ep_len)
                self.tensorboard_logger.add_scalar('episodic_return_train', ep_ret, timestep)
                self.tensorboard_logger.flush()
                # print(f"Episode reward: {ep_ret} | Episode Length: {ep_len}")
                state, ep_ret, ep_len = self.env.reset(), 0, 0
                episode += 1
                # Retrieve training reward
                x, y = self.logger.load_results(["EpLen", "EpRet"])
                if len(x) > 0:
                    # Mean training reward over the last 50 episodes
                    mean_reward = np.mean(y[-50:])

                    # New best model
                    if mean_reward > self.best_mean_reward:
                        print("Num timesteps: {}".format(timestep))
                        print("Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}".format(self.best_mean_reward, mean_reward))

                        self.best_mean_reward = mean_reward
                        self.save_weights(fname=f"best_{trial_num}.pth")
                    
                    if self.env.spec.reward_threshold is not None and self.best_mean_reward >= self.env.spec.reward_threshold:
                        print("Solved Environment, stopping iteration...")
                        return

                # self.evaluate_agent()
                self.logger.dump()

            if self.save_freq > 0 and timestep % self.save_freq == 0:
                self.save_weights(fname=f"latest_{trial_num}.pth")
        

    def learn(self, timesteps, num_trials=1):
        '''
        Function to learn using DDPG.
        Args:
            timesteps (int): number of timesteps to train for
            num_trials (int): Number of times to train the agent
        '''
        self.env.training=True
        best_reward_trial = -np.inf
        for trial in range(num_trials):
            self.tensorboard_logger = SummaryWriter(log_dir=os.path.join(self.tensorboard_logdir, f'{trial+1}'))
            self.learn_one_trial(timesteps, trial+1)
            
            if self.best_mean_reward > best_reward_trial:
                best_reward_trial = self.best_mean_reward
                self.save_weights(best=True)

            self.logger.reset()
            self.reinit_network()
            print()
            print(f"Trial {trial+1}/{num_trials} complete")

    def test(self, timesteps=None, render=False, record=False):
        '''
        Test the agent in the environment
        Args:
            render (bool): If true, render the image out for user to see in real time
            record (bool): If true, save the recording into a .gif file at the end of episode
            timesteps (int): number of timesteps to run the environment for. Default None will run to completion
        Return:
            Ep_Ret (int): Total reward from the episode
            Ep_Len (int): Total length of the episode in terms of timesteps
        '''
        self.env.training=False
        if render:
            self.env.render('human')
        state, done, ep_ret, ep_len = self.env.reset(), False, 0, 0
        img = []
        if record:
            img.append(self.env.render('rgb_array'))

        if timesteps is not None:
            for i in range(timesteps):
                # Take deterministic action with 0 noise added
                state, reward, done, _ = self.env.step(self.get_action(state, 0))
                if record:
                    img.append(self.env.render('rgb_array'))
                else:
                    self.env.render()
                ep_ret += reward
                ep_len += 1                
        else:
            while not (done or (ep_len==self.max_ep_len)):
                # Take deterministic action with 0 noise added
                state, reward, done, _ = self.env.step(self.get_action(state, 0))
                if record:
                    img.append(self.env.render('rgb_array'))
                else:
                    self.env.render()
                ep_ret += reward
                ep_len += 1

        if record:
            imageio.mimsave(os.path.join(self.save_dir, "recording.gif"), [np.array(frame) for i, frame in enumerate(img) if i % 2 == 0], fps=29)

        self.env.training=True
        return ep_ret, ep_len      
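
A minimal usage sketch for the DDPG class above, not part of the original listing. It assumes env_fn returns the project's environment wrapper (the class calls env.save, env.load and toggles env.training, which plain Gym environments do not provide) and that the surrounding module already supplies the project-local helpers such as Logger, ReplayBuffer and get_actor_critic_module. The environment id, wrapper name and paths below are placeholders.

import gym
import torch.nn as nn

def make_env():
    # CustomEnvWrapper is hypothetical; it stands in for the project's env wrapper
    return CustomEnvWrapper(gym.make("Pendulum-v0"))

agent = DDPG(env_fn=make_env,
             save_dir="./checkpoints/ddpg",
             tensorboard_logdir="./tb_logs/ddpg",
             ac_kwargs=dict(hidden_sizes=(256, 256), activation=nn.ReLU),
             seed=0)
agent.learn(timesteps=100000, num_trials=1)    # train; best weights are checkpointed
agent.load_weights(best=True)                  # restore the best checkpoint
ep_ret, ep_len = agent.test(render=False)      # one deterministic rollout
print(f"return={ep_ret:.1f}, length={ep_len}")
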
Example #2
class DAC_PPO:
    '''
    DAC + PPO
    '''
    def __init__(self,
                 env_fn,
                 save_dir,
                 tensorboard_logdir=None,
                 optimizer_class=Adam,
                 weight_decay=0,
                 oc_kwargs=dict(),
                 logger_kwargs=dict(),
                 lr=1e-3,
                 optimization_epochs=5,
                 mini_batch_size=64,
                 ppo_ratio_clip=0.2,
                 gamma=0.99,
                 rollout_length=2048,
                 beta_weight=0,
                 entropy_weight=0.01,
                 gradient_clip=5,
                 gae_tau=0.95,
                 max_ep_len=2000,
                 save_freq=200,
                 seed=0,
                 **kwargs):

        self.seed = seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.lr = lr
        self.env_fn = env_fn
        self.env = env_fn()
        self.oc_kwargs = oc_kwargs
        self.network_fn = self.get_network_fn(self.oc_kwargs)
        self.network = self.network_fn().to(self.device)
        self.optimizer_class = optimizer_class
        self.weight_decay = weight_decay
        self.optimizer = optimizer_class(self.network.parameters(),
                                         self.lr,
                                         weight_decay=self.weight_decay)
        self.gamma = gamma
        self.rollout_length = rollout_length
        self.num_options = oc_kwargs['num_options']
        self.beta_weight = beta_weight
        self.entropy_weight = entropy_weight
        self.gradient_clip = gradient_clip
        self.max_ep_len = max_ep_len
        self.save_freq = save_freq

        self.save_dir = save_dir
        self.logger = Logger(**logger_kwargs)
        self.tensorboard_logdir = tensorboard_logdir
        # self.tensorboard_logger = SummaryWriter(log_dir=tensorboard_logdir)

        self.is_initial_states = to_tensor(np.ones((1))).byte().to(self.device)
        self.prev_options = to_tensor(np.zeros((1))).long().to(self.device)

        self.best_mean_reward = -np.inf

        self.optimization_epochs = optimization_epochs
        self.mini_batch_size = mini_batch_size
        self.ppo_ratio_clip = ppo_ratio_clip
        self.gae_tau = gae_tau
        self.use_gae = self.gae_tau > 0

    def get_network_fn(self, oc_kwargs):
        activation = nn.ReLU
        gate = F.relu
        obs_space = self.env.observation_space.shape
        hidden_units = oc_kwargs['hidden_sizes']
        act_dim = self.env.action_space.shape[0]
        self.continuous = True

        if len(obs_space) > 1:
            # image observations
            phi_body = VAE(load_path=oc_kwargs['vae_weights_path'], device=self.device) \
                if oc_kwargs['model_type'].lower() == 'vae' \
                else ConvBody(obs_space, oc_kwargs['conv_layer_sizes'], activation, batchnorm=False)
            state_dim = phi_body.latent_dim
        else:
            state_dim = obs_space[0]
            phi_body = DummyBody(state_dim)

        network_fn = lambda: OptionGaussianActorCriticNet(
            state_dim,
            act_dim,
            num_options=oc_kwargs['num_options'],
            phi_body=phi_body,
            critic_body=FCBody(state_dim, hidden_units=hidden_units, gate=gate
                               ),
            option_body_fn=lambda: FCBody(
                state_dim, hidden_units=hidden_units, gate=gate),
            device=self.device)

        return network_fn

    def save_weights(self, best=False, fname=None):
        '''
        Save the PyTorch model weights of the option-critic network,
        and pickle the environment to preserve parameters such as normalisation statistics.
        Args:
            best (bool): if True, save as best.pth
            fname (string): if specified, save as <fname>
        '''
        if fname is not None:
            _fname = fname
        elif best:
            _fname = "best.pth"
        else:
            _fname = "model_weights.pth"

        print('saving checkpoint...')
        checkpoint = {'oc': self.network.state_dict()}
        torch.save(checkpoint, os.path.join(self.save_dir, _fname))
        self.env.save(os.path.join(self.save_dir, "env.json"))
        print(f"checkpoint saved at {os.path.join(self.save_dir, _fname)}")

    def load_weights(self, best=True, fname=None):
        '''
        Load the model weights from self.save_dir
        Args:
            best (bool): If True, load from the weights file with the best mean episode reward
            fname (string): If specified, load from <fname> instead
        '''
        if fname is not None:
            _fname = fname
        elif best:
            _fname = "best.pth"
        else:
            _fname = "model_weights.pth"
        checkpoint_path = os.path.join(self.save_dir, _fname)
        if os.path.isfile(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.network.load_state_dict(sanitise_state_dict(checkpoint['oc']))

            env_path = os.path.join(self.save_dir, "env.json")
            if os.path.isfile(env_path):
                self.env = self.env.load(env_path)
                print("Environment loaded")

            print('checkpoint loaded at {}'.format(checkpoint_path))
        else:
            raise OSError("Checkpoint file not found.")

    def reinit_network(self):
        self.seed += 1
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        self.best_mean_reward = -np.inf
        self.network = self.network_fn().to(self.device)
        self.optimizer = self.optimizer_class(self.network.parameters(),
                                              self.lr,
                                              weight_decay=self.weight_decay)

    def record_online_return(self, ep_ret, timestep, ep_len):
        self.tensorboard_logger.add_scalar('episodic_return_train', ep_ret,
                                           timestep)
        self.logger.store(EpRet=ep_ret, EpLen=ep_len)
        self.logger.dump()
        # print(f"episode return: {ep_ret}")

    def compute_pi_hat(self, prediction, prev_option, is_initial_states):
        inter_pi = prediction['inter_pi']
        mask = torch.zeros_like(inter_pi)
        mask[:, prev_option] = 1
        beta = prediction['beta']
        pi_hat = (1 - beta) * mask + beta * inter_pi
        is_initial_states = is_initial_states.view(-1, 1).expand(
            -1, inter_pi.size(1))
        pi_hat = torch.where(is_initial_states, inter_pi, pi_hat)
        return pi_hat

    def compute_pi_bar(self, options, action, mean, std):
        options = options.unsqueeze(-1).expand(-1, -1, mean.size(-1))
        mean = mean.gather(1, options).squeeze(1)
        std = std.gather(1, options).squeeze(1)
        dist = torch.distributions.Normal(mean, std)
        pi_bar = dist.log_prob(action).sum(-1).exp().unsqueeze(-1)
        return pi_bar

    def compute_log_pi_a(self, options, pi_hat, action, mean, std, mdp):
        if mdp == 'hat':
            return pi_hat.add(1e-5).log().gather(1, options)
        elif mdp == 'bar':
            pi_bar = self.compute_pi_bar(options, action, mean, std)
            return pi_bar.add(1e-5).log()
        else:
            raise NotImplementedError

    def compute_adv(self, storage, mdp):
        v = getattr(storage, 'v_%s' % mdp)
        adv = getattr(storage, 'adv_%s' % mdp)
        all_ret = getattr(storage, 'ret_%s' % mdp)

        ret = v[-1].detach()
        advantages = to_tensor(np.zeros((1))).to(self.device)
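        # GAE recursion: delta_t = r_t + gamma * m_t * V(s_{t+1}) - V(s_t),
        # A_t = delta_t + gamma * lambda * m_t * A_{t+1}; without GAE, A_t = ret_t - V(s_t)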
        for i in reversed(range(self.rollout_length)):
            ret = storage.r[i] + self.gamma * storage.m[i] * ret
            if not self.use_gae:
                advantages = ret - v[i].detach()
            else:
                td_error = storage.r[i] + self.gamma * storage.m[i] * v[
                    i + 1] - v[i]
                advantages = advantages * self.gae_tau * self.gamma * storage.m[
                    i] + td_error
            adv[i] = advantages.detach()
            all_ret[i] = ret.detach()

    def update(self, storage, mdp, timestep, freeze_v=False):
        states, actions, options, log_probs_old, returns, advantages, prev_options, inits, pi_hat, mean, std = \
            storage.cat(
                ['s', 'a', 'o', 'log_pi_%s' % (mdp), 'ret_%s' % (mdp), 'adv_%s' % (mdp), 'prev_o', 'init', 'pi_hat',
                 'mean', 'std'])
        actions = actions.detach()
        log_probs_old = log_probs_old.detach()
        pi_hat = pi_hat.detach()
        mean = mean.detach()
        std = std.detach()
        advantages = (advantages - advantages.mean()) / advantages.std()

        for _ in range(self.optimization_epochs):
            sampler = random_sample(np.arange(states.size(0)),
                                    self.mini_batch_size)
            for batch_indices in sampler:
                batch_indices = to_tensor(batch_indices).long()

                sampled_pi_hat = pi_hat[batch_indices]
                sampled_mean = mean[batch_indices]
                sampled_std = std[batch_indices]
                sampled_states = states[batch_indices]
                sampled_prev_o = prev_options[batch_indices]
                sampled_init = inits[batch_indices]

                sampled_options = options[batch_indices]
                sampled_actions = actions[batch_indices]
                sampled_log_probs_old = log_probs_old[batch_indices]
                sampled_returns = returns[batch_indices]
                sampled_advantages = advantages[batch_indices]

                prediction = self.network(sampled_states, unsqueeze=False)
                if mdp == 'hat':
                    cur_pi_hat = self.compute_pi_hat(prediction,
                                                     sampled_prev_o.view(-1),
                                                     sampled_init.view(-1))
                    entropy = -(cur_pi_hat *
                                cur_pi_hat.add(1e-5).log()).sum(-1).mean()
                    log_pi_a = self.compute_log_pi_a(sampled_options,
                                                     cur_pi_hat,
                                                     sampled_actions,
                                                     sampled_mean, sampled_std,
                                                     mdp)
                    beta_loss = prediction['beta'].mean()
                elif mdp == 'bar':
                    log_pi_a = self.compute_log_pi_a(sampled_options,
                                                     sampled_pi_hat,
                                                     sampled_actions,
                                                     prediction['mean'],
                                                     prediction['std'], mdp)
                    entropy = 0
                    beta_loss = 0
                else:
                    raise NotImplementedError

                if mdp == 'bar':
                    v = prediction['q_o'].gather(1, sampled_options)
                elif mdp == 'hat':
                    v = (prediction['q_o'] *
                         sampled_pi_hat).sum(-1).unsqueeze(-1)
                else:
                    raise NotImplementedError

                ratio = (log_pi_a - sampled_log_probs_old).exp()
                obj = ratio * sampled_advantages
                obj_clipped = ratio.clamp(
                    1.0 - self.ppo_ratio_clip,
                    1.0 + self.ppo_ratio_clip) * sampled_advantages
                policy_loss = -torch.min(obj, obj_clipped).mean() - self.entropy_weight * entropy + \
                              self.beta_weight * beta_loss

                # discarded = (obj > obj_clipped).float().mean()
                value_loss = 0.5 * (sampled_returns - v).pow(2).mean()

                self.tensorboard_logger.add_scalar(f"loss/{mdp}_value_loss",
                                                   value_loss.item(), timestep)
                self.tensorboard_logger.add_scalar(f"loss/{mdp}_policy_loss",
                                                   policy_loss.item(),
                                                   timestep)
                self.tensorboard_logger.add_scalar(
                    f"loss/{mdp}_beta_loss", beta_loss if isinstance(
                        beta_loss, int) else beta_loss.item(), timestep)

                if freeze_v:
                    value_loss = 0

                self.optimizer.zero_grad()
                (policy_loss + value_loss).backward()
                nn.utils.clip_grad_norm_(self.network.parameters(),
                                         self.gradient_clip)
                self.optimizer.step()

    def learn_one_trial(self, num_timesteps, trial_num=1):
        self.states, ep_ret, ep_len = self.env.reset(), 0, 0
        storage = Storage(self.rollout_length,
                          ['adv_bar', 'adv_hat', 'ret_bar', 'ret_hat'])
        states = self.states
        for timestep in tqdm(range(1, num_timesteps + 1)):
            prediction = self.network(states)
            pi_hat = self.compute_pi_hat(prediction, self.prev_options,
                                         self.is_initial_states)
            dist = torch.distributions.Categorical(probs=pi_hat)
            options = dist.sample()

            # Gaussian policy
            mean = prediction['mean'][0, options]
            std = prediction['std'][0, options]
            dist = torch.distributions.Normal(mean, std)

            # select action
            actions = dist.sample()

            pi_bar = self.compute_pi_bar(options.unsqueeze(-1), actions,
                                         prediction['mean'], prediction['std'])

            v_bar = prediction['q_o'].gather(1, options.unsqueeze(-1))
            v_hat = (prediction['q_o'] * pi_hat).sum(-1).unsqueeze(-1)

            next_states, rewards, terminals, _ = self.env.step(
                to_np(actions[0]))
            ep_ret += rewards
            ep_len += 1

            # end of episode handling
            if terminals or ep_len == self.max_ep_len:
                next_states = self.env.reset()
                self.record_online_return(ep_ret, timestep, ep_len)
                ep_ret, ep_len = 0, 0
                # Retrieve training reward
                x, y = self.logger.load_results(["EpLen", "EpRet"])
                if len(x) > 0:
                    # Mean training reward over the last 50 episodes
                    mean_reward = np.mean(y[-50:])
                    # New best model
                    if mean_reward > self.best_mean_reward:
                        print("Num timesteps: {}".format(timestep))
                        print(
                            "Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}"
                            .format(self.best_mean_reward, mean_reward))

                        self.best_mean_reward = mean_reward
                        self.save_weights(fname=f"best_{trial_num}.pth")

                    if self.env.spec.reward_threshold is not None and self.best_mean_reward >= self.env.spec.reward_threshold:
                        print("Solved Environment, stopping iteration...")
                        return

            storage.add(prediction)
            storage.add({
                'r': to_tensor(rewards).to(self.device).unsqueeze(-1),
                'm': to_tensor(1 - terminals).to(self.device).unsqueeze(-1),
                'a': actions,
                'o': options.unsqueeze(-1),
                'prev_o': self.prev_options.unsqueeze(-1),
                's': to_tensor(states).unsqueeze(0),
                'init': self.is_initial_states.unsqueeze(-1),
                'pi_hat': pi_hat,
                'log_pi_hat': pi_hat[0, options].add(1e-5).log().unsqueeze(-1),
                'log_pi_bar': pi_bar.add(1e-5).log(),
                'v_bar': v_bar,
                'v_hat': v_hat
            })

            self.is_initial_states = to_tensor(terminals).unsqueeze(-1).to(
                self.device).byte()
            self.prev_options = options
            states = next_states

            if timestep % self.rollout_length == 0:
                self.states = states
                prediction = self.network(states)
                pi_hat = self.compute_pi_hat(prediction, self.prev_options,
                                             self.is_initial_states)
                dist = torch.distributions.Categorical(pi_hat)
                options = dist.sample()
                v_bar = prediction['q_o'].gather(1, options.unsqueeze(-1))
                v_hat = (prediction['q_o'] * pi_hat).sum(-1).unsqueeze(-1)

                storage.add(prediction)
                storage.add({
                    'v_bar': v_bar,
                    'v_hat': v_hat,
                })
                storage.placeholder()

                self.compute_adv(storage, 'bar')
                self.compute_adv(storage, 'hat')
                mdps = ['hat', 'bar']
                np.random.shuffle(mdps)
                self.update(storage, mdps[0], timestep)
                self.update(storage, mdps[1], timestep)

                storage = Storage(self.rollout_length,
                                  ['adv_bar', 'adv_hat', 'ret_bar', 'ret_hat'])

            if self.save_freq > 0 and timestep % self.save_freq == 0:
                self.save_weights(fname=f"latest_{trial_num}.pth")

    def learn(self, timesteps, num_trials=1):
        '''
        Function to learn using DAC+PPO.
        Args:
            timesteps (int): number of timesteps to train for
            num_trials (int): Number of times to train the agent
        '''
        self.env.training = True
        self.network.train()
        best_reward_trial = -np.inf
        for trial in range(num_trials):
            self.tensorboard_logger = SummaryWriter(
                log_dir=os.path.join(self.tensorboard_logdir, f'{trial+1}'))
            self.learn_one_trial(timesteps, trial + 1)

            if self.best_mean_reward > best_reward_trial:
                best_reward_trial = self.best_mean_reward
                self.save_weights(best=True)

            self.logger.reset()
            self.reinit_network()
            print()
            print(f"Trial {trial+1}/{num_trials} complete")

    def test(self, timesteps=None, render=False, record=False):
        '''
        Test the agent in the environment
        Args:
            render (bool): If true, render the image out for user to see in real time
            record (bool): If true, save the recording into a .gif file at the end of episode
            timesteps (int): number of timesteps to run the environment for. Default None will run to completion
        Return:
            Ep_Ret (int): Total reward from the episode
            Ep_Len (int): Total length of the episode in terms of timesteps
        '''
        self.env.training = False
        self.network.eval()
        if render:
            self.env.render('human')
        states, terminals, ep_ret, ep_len = self.env.reset(), False, 0, 0
        is_initial_states = to_tensor(np.ones((1))).byte().to(self.device)
        prev_options = to_tensor(np.zeros((1))).long().to(self.device)
        img = []
        if record:
            img.append(self.env.render('rgb_array'))

        if timesteps is not None:
            for i in range(timesteps):
                prediction = self.network(states)
                pi_hat = self.compute_pi_hat(prediction, prev_options,
                                             is_initial_states)
                dist = torch.distributions.Categorical(probs=pi_hat)
                options = dist.sample()

                # Gaussian policy
                mean = prediction['mean'][0, options]
                std = prediction['std'][0, options]
                dist = torch.distributions.Normal(mean, std)

                # select action
                actions = mean

                next_states, rewards, terminals, _ = self.env.step(
                    to_np(actions[0]))
                is_initial_states = to_tensor(terminals).unsqueeze(
                    -1).byte().to(self.device)
                prev_options = options
                states = next_states
                if record:
                    img.append(self.env.render('rgb_array'))
                else:
                    self.env.render()
                ep_ret += rewards
                ep_len += 1
        else:
            while not (terminals or (ep_len == self.max_ep_len)):
                # select option
                prediction = self.network(states)
                pi_hat = self.compute_pi_hat(prediction, prev_options,
                                             is_initial_states)
                dist = torch.distributions.Categorical(probs=pi_hat)
                options = dist.sample()

                # Gaussian policy
                mean = prediction['mean'][0, options]
                std = prediction['std'][0, options]
                # dist = torch.distributions.Normal(mean, std)

                # select action
                actions = mean

                next_states, rewards, terminals, _ = self.env.step(
                    to_np(actions[0]))
                is_initial_states = to_tensor(terminals).unsqueeze(
                    -1).byte().to(self.device)
                prev_options = options
                states = next_states
                if record:
                    img.append(self.env.render('rgb_array'))
                else:
                    self.env.render()

                ep_ret += rewards
                ep_len += 1

        if record:
            imageio.mimsave(
                os.path.join(self.save_dir, "recording.gif"),
                [np.array(frame) for i, frame in enumerate(img) if i % 2 == 0],
                fps=29)

        self.env.training = True
        return ep_ret, ep_len
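
A small standalone sketch, not from the original code, of the option-mixture rule implemented in compute_pi_hat above: off initial states the previous option is kept with probability (1 - beta) and a new option is drawn from the master policy inter_pi with probability beta; on initial states the master policy is used directly. The numbers are dummy values for illustration.

import torch

inter_pi = torch.tensor([[0.2, 0.5, 0.3]])        # master policy over 3 options
beta = torch.tensor([[0.1, 0.1, 0.1]])            # termination probability per option
prev_option = torch.tensor([1])
is_initial = torch.tensor([False])                # not an initial state

mask = torch.zeros_like(inter_pi)
mask[:, prev_option] = 1                          # one-hot of the previous option
pi_hat = (1 - beta) * mask + beta * inter_pi      # mostly keeps option 1
pi_hat = torch.where(is_initial.view(-1, 1).expand_as(inter_pi), inter_pi, pi_hat)
print(pi_hat)                                     # tensor([[0.0200, 0.9500, 0.0300]])
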
Example #3
class TRPO:
    def __init__(self,
                 env_fn,
                 save_dir,
                 ac_kwargs=dict(),
                 seed=0,
                 tensorboard_logdir=None,
                 steps_per_epoch=400,
                 batch_size=400,
                 gamma=0.99,
                 delta=0.01,
                 vf_lr=1e-3,
                 train_v_iters=80,
                 damping_coeff=0.1,
                 cg_iters=10,
                 backtrack_iters=10,
                 backtrack_coeff=0.8,
                 lam=0.97,
                 max_ep_len=1000,
                 logger_kwargs=dict(),
                 save_freq=10,
                 algo='trpo',
                 ngpu=1):
        """
        Trust Region Policy Optimization 
        (with support for Natural Policy Gradient)
        Args:
            env_fn : A function which creates a copy of the environment.
                The environment must satisfy the OpenAI Gym API.
            save_dir: path to save directory
            actor_critic: Class for the actor-critic pytorch module
            ac_kwargs (dict): Any kwargs appropriate for the actor_critic 
                function you provided to TRPO.
            seed (int): Seed for random number generators.
            steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
                for the agent and the environment in each epoch.
            batch_size (int): The buffer is split into batches of batch_size to learn from
            gamma (float): Discount factor. (Always between 0 and 1.)
            delta (float): KL-divergence limit for TRPO / NPG update. 
                (Should be small for stability. Values like 0.01, 0.05.)
            vf_lr (float): Learning rate for value function optimizer.
            train_v_iters (int): Number of gradient descent steps to take on 
                value function per epoch.
            damping_coeff (float): Artifact for numerical stability, should be 
                smallish. Adjusts Hessian-vector product calculation:
                
                .. math:: Hv \\rightarrow (\\alpha I + H)v
                where :math:`\\alpha` is the damping coefficient. 
                Probably don't play with this hyperparameter.
            cg_iters (int): Number of iterations of conjugate gradient to perform. 
                Increasing this will lead to a more accurate approximation
                to :math:`H^{-1} g`, and possibly slightly-improved performance,
                but at the cost of slowing things down. 
                Also probably don't play with this hyperparameter.
            backtrack_iters (int): Maximum number of steps allowed in the 
                backtracking line search. Since the line search usually doesn't 
                backtrack, and usually only steps back once when it does, this
                hyperparameter doesn't often matter.
            backtrack_coeff (float): How far back to step during backtracking line
                search. (Always between 0 and 1, usually above 0.5.)
            lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
                close to 1.)
            max_ep_len (int): Maximum length of trajectory / episode / rollout.
            logger_kwargs (dict): Keyword args for Logger. 
                            (1) output_dir = None
                            (2) output_fname = 'progress.pickle'
            save_freq (int): How often (in terms of gap between epochs) to save
                the current policy and value function.
            algo: Either 'trpo' or 'npg': this code supports both, since they are 
                almost the same.
        """
        # logger stuff
        self.logger = Logger(**logger_kwargs)

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.env = env_fn()
        self.vf_lr = vf_lr
        self.steps_per_epoch = steps_per_epoch  # if steps_per_epoch > self.env.spec.max_episode_steps else self.env.spec.max_episode_steps
        self.max_ep_len = max_ep_len
        self.train_v_iters = train_v_iters

        # Main network
        self.ngpu = ngpu
        self.actor_critic = get_actor_critic_module(ac_kwargs, 'trpo')
        self.ac_kwargs = ac_kwargs
        self.ac = self.actor_critic(self.env.observation_space,
                                    self.env.action_space,
                                    device=self.device,
                                    ngpu=self.ngpu,
                                    **ac_kwargs)

        # Create Optimizers
        self.v_optimizer = optim.Adam(self.ac.v.parameters(), lr=self.vf_lr)

        # GAE buffer
        self.gamma = gamma
        self.lam = lam
        self.obs_dim = self.env.observation_space.shape
        self.act_dim = self.env.action_space.shape
        self.buffer = GAEBuffer(self.obs_dim, self.act_dim,
                                self.steps_per_epoch, self.device, self.gamma,
                                self.lam)
        self.batch_size = batch_size

        self.cg_iters = cg_iters
        self.damping_coeff = damping_coeff
        self.delta = delta
        self.backtrack_coeff = backtrack_coeff
        self.algo = algo
        self.backtrack_iters = backtrack_iters
        self.best_mean_reward = -np.inf
        self.save_dir = save_dir
        self.save_freq = save_freq

        self.tensorboard_logdir = tensorboard_logdir

    def reinit_network(self):
        '''
        Re-initialize network weights and optimizers for a fresh agent to train
        '''
        # Main network
        self.best_mean_reward = -np.inf
        self.ac = self.actor_critic(self.env.observation_space,
                                    self.env.action_space,
                                    device=self.device,
                                    ngpu=self.ngpu,
                                    **self.ac_kwargs)

        # Create Optimizers
        self.v_optimizer = optim.Adam(self.ac.v.parameters(), lr=self.vf_lr)
        self.buffer = GAEBuffer(self.obs_dim, self.act_dim,
                                self.steps_per_epoch, self.device, self.gamma,
                                self.lam)

    def flat_grad(self, grads, hessian=False):
        if not hessian:
            return torch.cat([grad.view(-1) for grad in grads])
        return torch.cat([grad.contiguous().view(-1) for grad in grads]).data

    def cg(self, obs, b, EPS=1e-8, residual_tol=1e-10):
        # Conjugate gradient algorithm
        # (https://en.wikipedia.org/wiki/Conjugate_gradient_method)
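        # Iteratively solves H x = b for x (so x ~ H^-1 g) using only
        # Hessian-vector products; the Hessian H is never formed explicitly.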
        x = torch.zeros(b.size()).to(self.device)
        r = b.clone()
        p = r.clone()
        rdotr = torch.dot(r, r).to(self.device)

        for _ in range(self.cg_iters):
            Ap = self.hessian_vector_product(obs, p)
            alpha = rdotr / (torch.dot(p, Ap).to(self.device) + EPS)

            x += alpha * p
            r -= alpha * Ap

            new_rdotr = torch.dot(r, r)
            p = r + (new_rdotr / rdotr) * p
            rdotr = new_rdotr

            if rdotr < residual_tol:
                break

        return x

    def hessian_vector_product(self, obs, p):
        p = p.detach()
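        # Pearlmutter trick: Hv = grad( grad(KL) . v ) w.r.t. the policy parameters,
        # plus damping_coeff * v for numerical stability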
        kl = self.ac.pi.calculate_kl(old_policy=self.ac.pi,
                                     new_policy=self.ac.pi,
                                     obs=obs)
        kl_grad = torch.autograd.grad(kl,
                                      self.ac.pi.parameters(),
                                      create_graph=True)
        kl_grad = self.flat_grad(kl_grad)

        kl_grad_p = (kl_grad * p).sum()
        kl_hessian = torch.autograd.grad(kl_grad_p, self.ac.pi.parameters())
        kl_hessian = self.flat_grad(kl_hessian, hessian=True)
        return kl_hessian + p * self.damping_coeff

    def flat_params(self, model):
        params = []
        for param in model.parameters():
            params.append(param.data.view(-1))
        params_flatten = torch.cat(params)
        return params_flatten

    def update_model(self, model, new_params):
        index = 0
        for params in model.parameters():
            params_length = len(params.view(-1))
            new_param = new_params[index:index + params_length]
            new_param = new_param.view(params.size())
            params.data.copy_(new_param)
            index += params_length

    def update(self):
        self.ac.train()
        data = self.buffer.get()
        obs_ = data['obs']
        act_ = data['act']
        ret_ = data['ret']
        adv_ = data['adv']
        logp_old_ = data['logp']

        for index in BatchSampler(
                SubsetRandomSampler(range(self.steps_per_epoch)),
                self.batch_size, False):
            obs = obs_[index]
            act = act_[index]
            ret = ret_[index]
            adv = adv_[index]
            logp_old = logp_old_[index]

            # Prediction logπ_old(s), logπ(s)
            _, logp = self.ac.pi(obs, act)

            # Policy loss
            ratio_old = torch.exp(logp - logp_old)
            surrogate_adv_old = (ratio_old * adv).mean()

            # Policy gradient as per the algorithm, flattened for the matrix calculations later
            gradient = torch.autograd.grad(
                surrogate_adv_old, self.ac.pi.parameters()
            )  # gradient of the surrogate objective w.r.t. the policy parameters
            gradient = self.flat_grad(gradient)

            # Core calculations for NPG/TRPO
            search_dir = self.cg(obs, gradient.data)  # H^-1 g
            gHg = (self.hessian_vector_product(obs, search_dir) *
                   search_dir).sum(0)
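            # Largest step along s = H^-1 g that keeps the quadratic KL estimate
            # (1/2) * s^T H s within delta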
            step_size = torch.sqrt(2 * self.delta / gHg)
            old_params = self.flat_params(self.ac.pi)
            # update the old model, calculate KL divergence then decide whether to update new model
            self.update_model(self.ac.pi_old, old_params)

            if self.algo == 'npg':
                params = old_params + step_size * search_dir
                self.update_model(self.ac.pi, params)

                kl = self.ac.pi.calculate_kl(new_policy=self.ac.pi,
                                             old_policy=self.ac.pi_old,
                                             obs=obs)
            elif self.algo == 'trpo':
                for i in range(self.backtrack_iters):
                    # Backtracking line search
                    # (https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf) 464p.
                    params = old_params + (self.backtrack_coeff**
                                           (i + 1)) * step_size * search_dir
                    self.update_model(self.ac.pi, params)

                    # Prediction logπ_old(s), logπ(s)
                    _, logp = self.ac.pi(obs, act)

                    # Policy loss
                    ratio = torch.exp(logp - logp_old)
                    surrogate_adv = (ratio * adv).mean()

                    improve = surrogate_adv - surrogate_adv_old
                    kl = self.ac.pi.calculate_kl(new_policy=self.ac.pi,
                                                 old_policy=self.ac.pi_old,
                                                 obs=obs)

                    # print(f"kl: {kl}")
                    if kl <= self.delta and improve > 0:
                        print(
                            'Accepting new params at step %d of line search.' %
                            i)
                        # self.backtrack_iters.append(i)
                        # log backtrack_iters=i
                        break

                    if i == self.backtrack_iters - 1:
                        print('Line search failed! Keeping old params.')
                        # self.backtrack_iters.append(i)
                        # log backtrack_iters=i

                        params = self.flat_params(self.ac.pi_old)
                        self.update_model(self.ac.pi, params)

            # Update Critic
            for _ in range(self.train_v_iters):
                self.v_optimizer.zero_grad()
                v = self.ac.v(obs)
                v_loss = ((v - ret)**2).mean()
                v_loss.backward()
                self.v_optimizer.step()

    def save_weights(self, best=False, fname=None):
        '''
        save the pytorch model weights of critic and actor networks
        '''
        if fname is not None:
            _fname = fname
        elif best:
            _fname = "best.pth"
        else:
            _fname = "model_weights.pth"

        print('saving checkpoint...')
        checkpoint = {
            'v': self.ac.v.state_dict(),
            'pi': self.ac.pi.state_dict(),
            'v_optimizer': self.v_optimizer.state_dict()
        }
        torch.save(checkpoint, os.path.join(self.save_dir, _fname))
        self.env.save(os.path.join(self.save_dir, "env.json"))
        print(f"checkpoint saved at {os.path.join(self.save_dir, _fname)}")

    def load_weights(self, best=True):
        '''
        Load the model weights from self.save_dir
        Args:
            best (bool): If True, load from the weights file with the best mean episode reward
        '''
        if best:
            fname = "best.pth"
        else:
            fname = "model_weights.pth"
        checkpoint_path = os.path.join(self.save_dir, fname)
        if os.path.isfile(checkpoint_path):
            key = 'cuda' if torch.cuda.is_available() else 'cpu'
            checkpoint = torch.load(checkpoint_path, map_location=key)
            self.ac.v.load_state_dict(
                sanitise_state_dict(checkpoint['v'], self.ngpu > 1))
            self.ac.pi.load_state_dict(
                sanitise_state_dict(checkpoint['pi'], self.ngpu > 1))
            self.v_optimizer.load_state_dict(
                sanitise_state_dict(checkpoint['v_optimizer'], self.ngpu > 1))

            env_path = os.path.join(self.save_dir, "env.json")
            if os.path.isfile(env_path):
                self.env = self.env.load(env_path)
                print("Environment loaded")

            print('checkpoint loaded at {}'.format(checkpoint_path))
        else:
            raise OSError("Checkpoint file not found.")

    def learn_one_trial(self, timesteps, trial_num):
        ep_rets = []
        epochs = int((timesteps / self.steps_per_epoch) + 0.5)
        print(
            "Rounded off to {} epochs with {} steps per epoch, total {} timesteps"
            .format(epochs, self.steps_per_epoch,
                    epochs * self.steps_per_epoch))
        start_time = time.time()
        obs, ep_ret, ep_len = self.env.reset(), 0, 0
        ep_num = 0
        for epoch in tqdm(range(epochs)):
            for t in range(self.steps_per_epoch):
                # step the environment
                a, v, logp = self.ac.step(
                    torch.as_tensor(obs, dtype=torch.float32).to(self.device))
                next_obs, reward, done, _ = self.env.step(a)
                ep_ret += reward
                ep_len += 1

                # Add experience to buffer
                self.buffer.store(obs, a, reward, v, logp)

                obs = next_obs
                timeout = ep_len == self.max_ep_len
                terminal = done or timeout
                epoch_ended = t == self.steps_per_epoch - 1

                # End of trajectory/episode handling
                if terminal or epoch_ended:
                    if timeout or epoch_ended:
                        _, v, _ = self.ac.step(
                            torch.as_tensor(obs, dtype=torch.float32).to(
                                self.device))
                    else:
                        v = 0

                    ep_num += 1
                    self.logger.store(EpRet=ep_ret, EpLen=ep_len)
                    self.tensorboard_logger.add_scalar(
                        'episodic_return_train', ep_ret,
                        epoch * self.steps_per_epoch + (t + 1))
                    self.buffer.finish_path(v)
                    obs, ep_ret, ep_len = self.env.reset(), 0, 0
                    # Retrieve training reward
                    x, y = self.logger.load_results(["EpLen", "EpRet"])
                    if len(x) > 0:
                        # Mean training reward over the last 50 episodes
                        mean_reward = np.mean(y[-50:])

                        # New best model
                        if mean_reward > self.best_mean_reward:
                            # print("Num timesteps: {}".format(timestep))
                            print(
                                "Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}"
                                .format(self.best_mean_reward, mean_reward))

                            self.best_mean_reward = mean_reward
                            self.save_weights(fname=f"best_{trial_num}.pth")

                        if self.env.spec.reward_threshold is not None and self.best_mean_reward >= self.env.spec.reward_threshold:
                            print("Solved Environment, stopping iteration...")
                            return

            # Update the value function and the TRPO policy
            self.update()
            self.logger.dump()
            if self.save_freq > 0 and epoch % self.save_freq == 0:
                self.save_weights(fname=f"latest_{trial_num}.pth")

    def learn(self, timesteps, num_trials=1):
        '''
        Function to learn using TRPO.
        Args:
            timesteps (int): number of timesteps to train for
            num_trials (int): Number of times to train the agent
        '''
        self.env.training = True
        best_reward_trial = -np.inf
        for trial in range(num_trials):
            self.tensorboard_logger = SummaryWriter(
                log_dir=os.path.join(self.tensorboard_logdir, f'{trial+1}'))
            self.learn_one_trial(timesteps, trial + 1)

            if self.best_mean_reward > best_reward_trial:
                best_reward_trial = self.best_mean_reward
                self.save_weights(best=True)

            self.logger.reset()
            self.reinit_network()
            print()
            print(f"Trial {trial+1}/{num_trials} complete")

    def test(self, timesteps=None, render=False, record=False):
        '''
        Test the agent in the environment
        Args:
            render (bool): If true, render the image out for user to see in real time
            record (bool): If true, save the recording into a .gif file at the end of episode
            timesteps (int): number of timesteps to run the environment for. Default None will run to completion
        Return:
            Ep_Ret (int): Total reward from the episode
            Ep_Len (int): Total length of the episode in terms of timesteps
        '''
        self.env.training = False
        if render:
            self.env.render('human')
        obs, done, ep_ret, ep_len = self.env.reset(), False, 0, 0
        img = []
        if record:
            img.append(self.env.render('rgb_array'))

        if timesteps is not None:
            for i in range(timesteps):
                # Take stochastic action with policy network
                action, _, _ = self.ac.step(
                    torch.as_tensor(obs, dtype=torch.float32).to(self.device))
                obs, reward, done, _ = self.env.step(action)
                if record:
                    img.append(self.env.render('rgb_array'))
                else:
                    self.env.render()
                ep_ret += reward
                ep_len += 1
        else:
            while not (done or (ep_len == self.max_ep_len)):
                # Take stochastic action with policy network
                action, _, _ = self.ac.step(
                    torch.as_tensor(obs, dtype=torch.float32).to(self.device))
                obs, reward, done, _ = self.env.step(action)
                if record:
                    img.append(self.env.render('rgb_array'))
                else:
                    self.env.render()
                ep_ret += reward
                ep_len += 1

        self.env.training = True
        if record:
            imageio.mimsave(
                os.path.join(self.save_dir, "recording.gif"),
                [np.array(frame) for i, frame in enumerate(img) if i % 2 == 0],
                fps=29)

        return ep_ret, ep_len
class Option_Critic:
    def __init__(self,
                 env_fn,
                 save_dir,
                 tensorboard_logdir=None,
                 optimizer_class=RMSprop,
                 oc_kwargs=dict(),
                 logger_kwargs=dict(),
                 eps_start=1.0,
                 eps_end=0.1,
                 eps_decay=1e4,
                 lr=1e-3,
                 gamma=0.99,
                 rollout_length=2048,
                 beta_reg=0.01,
                 entropy_weight=0.01,
                 gradient_clip=5,
                 target_network_update_freq=200,
                 max_ep_len=2000,
                 save_freq=200,
                 seed=0,
                 **kwargs):

        self.seed = seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.lr = lr
        self.env_fn = env_fn
        self.env = env_fn()
        self.oc_kwargs = oc_kwargs
        self.network_fn = self.get_network_fn(self.oc_kwargs)
        self.network = self.network_fn().to(self.device)
        self.target_network = self.network_fn().to(self.device)
        self.optimizer_class = optimizer_class
        self.optimizer = optimizer_class(self.network.parameters(), self.lr)
        self.target_network.load_state_dict(self.network.state_dict())
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        self.eps_schedule = LinearSchedule(eps_start, eps_end, eps_decay)
        self.gamma = gamma
        self.rollout_length = rollout_length
        self.num_options = oc_kwargs['num_options']
        self.beta_reg = beta_reg
        self.entropy_weight = entropy_weight
        self.gradient_clip = gradient_clip
        self.target_network_update_freq = target_network_update_freq
        self.max_ep_len = max_ep_len
        self.save_freq = save_freq

        self.save_dir = save_dir
        self.logger = Logger(**logger_kwargs)
        self.tensorboard_logdir = tensorboard_logdir
        # self.tensorboard_logger = SummaryWriter(log_dir=tensorboard_logdir)

        self.is_initial_states = to_tensor(np.ones((1))).byte()
        self.prev_options = self.is_initial_states.clone().long().to(
            self.device)

        self.best_mean_reward = -np.inf

    def get_network_fn(self, oc_kwargs):
        activation = nn.ReLU
        gate = F.relu
        obs_space = self.env.observation_space.shape
        hidden_units = oc_kwargs['hidden_sizes']
        act_dim = self.env.action_space.shape[0]
        self.continuous = True

        if len(obs_space) > 1:
            # image observations
            phi_body = VAE(load_path=oc_kwargs['vae_weights_path'], device=self.device) if oc_kwargs['model_type'].lower() == 'vae' \
                        else ConvBody(obs_space, oc_kwargs['conv_layer_sizes'], activation, batchnorm=True)
            state_dim = phi_body.latent_dim
        else:
            state_dim = obs_space[0]
            phi_body = FCBody(state_dim, hidden_units=hidden_units, gate=gate)

        network_fn = lambda: OptionCriticNet(body=phi_body,
                                             action_dim=act_dim,
                                             num_options=oc_kwargs[
                                                 'num_options'],
                                             device=self.device)

        return network_fn

    def update(self, storage, states, timestep):
        with torch.no_grad():
            prediction = self.target_network(states)
            # create the beta_adv attribute inside storage, initialised to [None] * rollout_length
            storage.placeholder()
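            # Bootstrapped target using the option-value upon arrival:
            # U(s', o) = (1 - beta(s', o)) * Q(s', o) + beta(s', o) * max_o' Q(s', o')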
            betas = prediction['beta'].squeeze()[self.prev_options]
            # single-worker rollout, so index row 0 of the (1, num_options) Q tensor
            ret = (1 - betas) * prediction['q'][0, self.prev_options] + \
                  betas * torch.max(prediction['q'], dim=-1)[0]
            ret = ret.unsqueeze(-1)

        for i in reversed(range(self.rollout_length)):
            ret = storage.r[i] + self.gamma * storage.m[i] * ret
            adv = ret - storage.q[i].gather(1, storage.o[i])
            storage.ret[i] = ret
            storage.adv[i] = adv

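            # Termination advantage for the previous option:
            # A_beta(s, o) = Q(s, o) - V(s) + beta_reg, where V(s) is the epsilon-greedy value over options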
            v = storage.q[i].max(dim=-1, keepdim=True)[0] * (
                1 - storage.eps[i]
            ) + storage.q[i].mean(-1).unsqueeze(-1) * storage.eps[i]
            q = storage.q[i].gather(1, storage.prev_o[i])
            storage.beta_adv[i] = q - v + self.beta_reg

        q, beta, log_pi, ret, adv, beta_adv, ent, option, action, initial_states, prev_o = \
            storage.cat(['q', 'beta', 'log_pi', 'ret', 'adv', 'beta_adv', 'ent', 'o', 'a', 'init', 'prev_o'])

        # calculate loss function
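        # q_loss: 0.5 * MSE between Q(s, o) and the bootstrapped return
        # pi_loss: intra-option policy gradient weighted by the advantage, with an entropy bonus
        # beta_loss: raises the termination probability of the previous option when its value
        #            falls below the state's option value (masked out at initial states)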
        q_loss = (q.gather(1, option) - ret.detach()).pow(2).mul(0.5).mean()
        pi_loss = -(log_pi.gather(1, action) *
                    adv.detach()) - self.entropy_weight * ent
        pi_loss = pi_loss.mean()
        beta_loss = (beta.gather(1, prev_o) * beta_adv.detach() *
                     (1 - initial_states)).mean()
        # logging all losses
        self.logger.store(q_loss=q_loss.item(),
                          pi_loss=pi_loss.item(),
                          beta_loss=beta_loss.item())
        self.tensorboard_logger.add_scalar("loss/q_loss", q_loss.item(),
                                           timestep)
        self.tensorboard_logger.add_scalar("loss/pi_loss", pi_loss.item(),
                                           timestep)
        self.tensorboard_logger.add_scalar("loss/beta_loss", beta_loss.item(),
                                           timestep)

        # backward and train
        self.optimizer.zero_grad()
        (pi_loss + q_loss + beta_loss).backward()
        nn.utils.clip_grad_norm_(self.network.parameters(), self.gradient_clip)
        self.optimizer.step()

    def save_weights(self, best=False, fname=None):
        '''
        Save the PyTorch model weights of the option-critic network and its target network,
        as well as serialising the environment to env.json to preserve env parameters like normalisation params
        Args:
            best (bool): if True, save the weights as best.pth
            fname (string): if specified, save the weights as <fname>
        '''
        if fname is not None:
            _fname = fname
        elif best:
            _fname = "best.pth"
        else:
            _fname = "model_weights.pth"

        print('saving checkpoint...')
        checkpoint = {
            'oc': self.network.state_dict(),
            'oc_target': self.target_network.state_dict()
        }
        torch.save(checkpoint, os.path.join(self.save_dir, _fname))
        self.env.save(os.path.join(self.save_dir, "env.json"))
        print(f"checkpoint saved at {os.path.join(self.save_dir, _fname)}")

    def load_weights(self, best=True, fname=None):
        '''
        Load the model weights from self.save_dir
        Args:
            best (bool): If True, load from the weights file with the best mean episode reward
            fname (string): if specified, load from <fname> in self.save_dir
        '''
        if fname is not None:
            _fname = fname
        elif best:
            _fname = "best.pth"
        else:
            _fname = "model_weights.pth"
        checkpoint_path = os.path.join(self.save_dir, _fname)
        if os.path.isfile(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.network.load_state_dict(sanitise_state_dict(checkpoint['oc']))
            self.target_network.load_state_dict(
                sanitise_state_dict(checkpoint['oc_target']))

            env_path = os.path.join(self.save_dir, "env.json")
            if os.path.isfile(env_path):
                self.env = self.env.load(env_path)
                print("Environment loaded")

            print('checkpoint loaded at {}'.format(checkpoint_path))
        else:
            raise OSError("Checkpoint file not found.")

    def reinit_network(self):
        self.seed += 1
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        self.network = self.network_fn().to(self.device)
        self.target_network = self.network_fn().to(self.device)
        self.optimizer = self.optimizer_class(self.network.parameters(),
                                              self.lr)
        self.target_network.load_state_dict(self.network.state_dict())
        self.eps_schedule = LinearSchedule(self.eps_start, self.eps_end,
                                           self.eps_decay)

    def sample_option(self, prediction, epsilon, prev_option,
                      is_initial_states):
        with torch.no_grad():
            # get q value
            q_option = prediction['q_o']
            pi_option = torch.zeros_like(q_option).add(epsilon /
                                                       q_option.size(1))

            # greedy policy
            greedy_option = q_option.argmax(dim=-1, keepdim=True)
            prob = 1 - epsilon + epsilon / q_option.size(1)
            prob = torch.zeros_like(pi_option).add(prob)
            pi_option.scatter_(1, greedy_option, prob)

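            # pi_hat: keep the previous option with probability (1 - beta),
            # otherwise re-sample from the epsilon-greedy option policy pi_option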
            mask = torch.zeros_like(q_option)
            mask[:, prev_option] = 1
            beta = prediction['beta']
            pi_hat_option = (1 - beta) * mask + beta * pi_option

            dist = torch.distributions.Categorical(probs=pi_option)
            options = dist.sample()
            dist = torch.distributions.Categorical(probs=pi_hat_option)
            options_hat = dist.sample()

            options = torch.where(is_initial_states.to(self.device), options,
                                  options_hat)
        return options

    def record_online_return(self, ep_ret, timestep, ep_len):
        self.tensorboard_logger.add_scalar('episodic_return_train', ep_ret,
                                           timestep)
        self.logger.store(EpRet=ep_ret, EpLen=ep_len)
        self.logger.dump()
        # print(f"episode return: {ep_ret}")

    def learn_one_trial(self, num_timesteps, trial_num=1):
        self.states, ep_ret, ep_len = self.env.reset(), 0, 0
        storage = Storage(self.rollout_length,
                          ['beta', 'o', 'beta_adv', 'prev_o', 'init', 'eps'])
        for timestep in tqdm(range(1, num_timesteps + 1)):
            prediction = self.network(self.states)
            epsilon = self.eps_schedule()
            # select option
            options = self.sample_option(prediction, epsilon,
                                         self.prev_options,
                                         self.is_initial_states)
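            # condition the intra-option policy on the sampled option before drawing an action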
            prediction['pi'] = prediction['pi'][0, options]
            prediction['log_pi'] = prediction['log_pi'][0, options]
            dist = torch.distributions.Categorical(probs=prediction['pi'])
            actions = dist.sample()
            entropy = dist.entropy()

            next_states, rewards, terminals, _ = self.env.step(to_np(actions))
            ep_ret += rewards
            ep_len += 1

            # end of episode handling
            if terminals or ep_len == self.max_ep_len:
                next_states = self.env.reset()
                self.record_online_return(ep_ret, timestep, ep_len)
                ep_ret, ep_len = 0, 0
                # Retrieve training reward
                x, y = self.logger.load_results(["EpLen", "EpRet"])
                if len(x) > 0:
                    # Mean training reward over the last 50 episodes
                    mean_reward = np.mean(y[-50:])
                    # New best model
                    if mean_reward > self.best_mean_reward:
                        print("Num timesteps: {}".format(timestep))
                        print(
                            "Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}"
                            .format(self.best_mean_reward, mean_reward))

                        self.best_mean_reward = mean_reward
                        self.save_weights(fname=f"best_{trial_num}.pth")

                    if self.env.spec.reward_threshold is not None and self.best_mean_reward >= self.env.spec.reward_threshold:
                        print("Solved Environment, stopping iteration...")
                        return

            storage.add(prediction)
            storage.add({
                'r': to_tensor(rewards).to(self.device).unsqueeze(-1),
                'm': to_tensor(1 - terminals).to(self.device).unsqueeze(-1),
                'o': options.unsqueeze(-1),
                'prev_o': self.prev_options.unsqueeze(-1),
                'ent': entropy,
                'a': actions.unsqueeze(-1),
                'init': self.is_initial_states.unsqueeze(-1).to(self.device).float(),
                'eps': epsilon
            })

            self.is_initial_states = to_tensor(terminals).unsqueeze(-1).byte()
            self.prev_options = options
            self.states = next_states

            if timestep % self.target_network_update_freq == 0:
                self.target_network.load_state_dict(self.network.state_dict())

            if timestep % self.rollout_length == 0:
                self.update(storage, self.states, timestep)
                storage = Storage(
                    self.rollout_length,
                    ['beta', 'o', 'beta_adv', 'prev_o', 'init', 'eps'])

            if self.save_freq > 0 and timestep % self.save_freq == 0:
                self.save_weights(fname=f"latest_{trial_num}.pth")

    def learn(self, timesteps, num_trials=1):
        '''
        Function to learn using the Option-Critic architecture.
        Args:
            timesteps (int): number of timesteps to train for
            num_trials (int): Number of times to train the agent
        '''
        self.env.training = True
        self.network.train()
        self.target_network.train()
        best_reward_trial = -np.inf
        for trial in range(num_trials):
            self.tensorboard_logger = SummaryWriter(
                log_dir=os.path.join(self.tensorboard_logdir, f'{trial+1}'))
            self.learn_one_trial(timesteps, trial + 1)

            if self.best_mean_reward > best_reward_trial:
                best_reward_trial = self.best_mean_reward
                self.save_weights(best=True)

            self.logger.reset()
            self.reinit_network()
            print()
            print(f"Trial {trial+1}/{num_trials} complete")

    def test(self, timesteps=None, render=False, record=False):
        '''
        Test the agent in the environment
        Args:
            render (bool): If true, render the image out for user to see in real time
            record (bool): If true, save the recording into a .gif file at the end of episode
            timesteps (int): number of timesteps to run the environment for. Default None will run to completion
        Return:
            Ep_Ret (int): Total reward from the episode
            Ep_Len (int): Total length of the episode in terms of timesteps
        '''
        self.env.training = False
        self.network.eval()
        self.target_network.eval()
        if render:
            self.env.render('human')
        states, done, ep_ret, ep_len = self.env.reset(), False, 0, 0
        is_initial_states = to_tensor(np.ones((1))).byte().to(self.device)
        prev_options = is_initial_states.clone().long().to(self.device)
        prediction = self.network(states)
        epsilon = 0.0
        # select option
        options = self.sample_option(prediction, epsilon, prev_options,
                                     is_initial_states)
        img = []
        if record:
            img.append(self.env.render('rgb_array'))

        if timesteps is not None:
            for i in range(timesteps):
                # recompute network outputs for the current state, then select an option
                prediction = self.network(states)
                options = self.sample_option(prediction, epsilon, prev_options,
                                             is_initial_states)

                # Gaussian policy
                mean = prediction['mean'][0, options]
                std = prediction['std'][0, options]
                dist = torch.distributions.Normal(mean, std)

                # select action
                actions = dist.sample()

                next_states, rewards, terminals, _ = self.env.step(
                    to_np(actions[0]))
                is_initial_states = to_tensor(terminals).unsqueeze(-1).byte()
                prev_options = options
                states = next_states
                if record:
                    img.append(self.env.render('rgb_array'))
                else:
                    self.env.render()
                ep_ret += rewards
                ep_len += 1
        else:
            while not (done or (ep_len == self.max_ep_len)):
                # recompute network outputs for the current state, then select an option
                prediction = self.network(states)
                options = self.sample_option(prediction, epsilon, prev_options,
                                             is_initial_states)

                # Gaussian policy
                mean = prediction['mean'][0, options]
                std = prediction['std'][0, options]
                dist = torch.distributions.Normal(mean, std)
                # select action
                actions = dist.sample()

                next_states, rewards, done, _ = self.env.step(
                    to_np(actions[0]))
                is_initial_states = to_tensor(done).unsqueeze(-1).byte()
                prev_options = options
                states = next_states
                if record:
                    img.append(self.env.render('rgb_array'))
                else:
                    self.env.render()

                ep_ret += rewards
                ep_len += 1

        if record:
            imageio.mimsave(
                os.path.join(self.save_dir, "recording.gif"),
                [np.array(frame) for i, frame in enumerate(img) if i % 2 == 0],
                fps=29)

        self.env.training = True
        return ep_ret, ep_len
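A minimal usage sketch for the Option_Critic agent above (not part of the original example). The env_fn, directory paths and oc_kwargs values are assumptions: env_fn must return this repository's wrapped gym environment (one that provides .training, .save/.load and .spec), and with a low-dimensional observation space only 'hidden_sizes' and 'num_options' are read from oc_kwargs.

# Hypothetical usage sketch -- env_fn, paths and hyperparameters are placeholders.
agent = Option_Critic(env_fn,                      # returns the repo's wrapped gym env
                      save_dir='./oc_checkpoints',
                      tensorboard_logdir='./oc_tb',
                      oc_kwargs=dict(hidden_sizes=(64, 64), num_options=2),
                      rollout_length=128,
                      max_ep_len=1000)
agent.learn(timesteps=100000, num_trials=1)        # train a single trial
ep_ret, ep_len = agent.test(render=False)          # greedy-option evaluation (epsilon = 0)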
Example #5
0
class PPO:
    def __init__(self,
                 env_fn,
                 save_dir,
                 ac_kwargs=dict(),
                 seed=0,
                 tensorboard_logdir=None,
                 steps_per_epoch=400,
                 batch_size=400,
                 gamma=0.99,
                 clip_ratio=0.2,
                 vf_lr=1e-3,
                 pi_lr=3e-4,
                 train_v_iters=80,
                 train_pi_iters=80,
                 lam=0.97,
                 max_ep_len=1000,
                 target_kl=0.01,
                 logger_kwargs=dict(),
                 save_freq=10,
                 ngpu=1):
        """
        Proximal Policy Optimization 
        Args:
            env_fn : A function which creates a copy of the environment.
                The environment must satisfy the OpenAI Gym API.
            save_dir: path to save directory
            actor_critic: Class for the actor-critic pytorch module
            ac_kwargs (dict): Any kwargs appropriate for the actor_critic
                function you provided to PPO.
            seed (int): Seed for random number generators.
            steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
                for the agent and the environment in each epoch.
            batch_size (int): The buffer is split into batches of batch_size to learn from
            gamma (float): Discount factor. (Always between 0 and 1.)
            clip_ratio (float): Hyperparameter for clipping in the policy objective.
                Roughly: how far can the new policy go from the old policy while 
                still profiting (improving the objective function)? The new policy 
                can still go farther than the clip_ratio says, but it doesn't help
                on the objective anymore. (Usually small, 0.1 to 0.3.) Typically
                denoted by :math:`\epsilon`. 
            pi_lr (float): Learning rate for policy optimizer.
            vf_lr (float): Learning rate for value function optimizer.
            train_v_iters (int): Number of gradient descent steps to take on 
                value function per epoch.
            train_pi_iters (int): Maximum number of gradient descent steps to take 
                on policy loss per epoch. (Early stopping may cause optimizer
                to take fewer than this.)    
            lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
                close to 1.)
            max_ep_len (int): Maximum length of trajectory / episode / rollout.
            target_kl (float): Roughly what KL divergence we think is appropriate
                between new and old policies after an update. This will get used 
                for early stopping. (Usually small, 0.01 or 0.05.)
            logger_kwargs (dict): Keyword args for Logger. 
                            (1) output_dir = None
                            (2) output_fname = 'progress.pickle'
            save_freq (int): How often (in terms of gap between epochs) to save
                the current policy and value function.
        """
        # logger stuff
        self.logger = Logger(**logger_kwargs)

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.env = env_fn()
        self.vf_lr = vf_lr
        self.pi_lr = pi_lr
        self.steps_per_epoch = steps_per_epoch  # if steps_per_epoch > self.env.spec.max_episode_steps else self.env.spec.max_episode_steps

        self.max_ep_len = max_ep_len
        # self.max_ep_len = self.env.spec.max_episode_steps if self.env.spec.max_episode_steps is not None else max_ep_len
        self.train_v_iters = train_v_iters
        self.train_pi_iters = train_pi_iters

        # Main network
        self.ngpu = ngpu
        self.actor_critic = get_actor_critic_module(ac_kwargs, 'ppo')
        self.ac_kwargs = ac_kwargs
        self.ac = self.actor_critic(self.env.observation_space,
                                    self.env.action_space,
                                    device=self.device,
                                    ngpu=self.ngpu,
                                    **ac_kwargs)

        # Create Optimizers
        self.v_optimizer = optim.Adam(self.ac.v.parameters(), lr=self.vf_lr)
        self.pi_optimizer = optim.Adam(self.ac.pi.parameters(), lr=self.pi_lr)

        # GAE buffer
        self.gamma = gamma
        self.lam = lam
        self.obs_dim = self.env.observation_space.shape
        self.act_dim = self.env.action_space.shape
        self.buffer = GAEBuffer(self.obs_dim, self.act_dim,
                                self.steps_per_epoch, self.device, self.gamma,
                                self.lam)
        self.batch_size = batch_size

        self.clip_ratio = clip_ratio
        self.target_kl = target_kl
        self.best_mean_reward = -np.inf
        self.save_dir = save_dir
        self.save_freq = save_freq

        self.tensorboard_logdir = tensorboard_logdir

    def reinit_network(self):
        '''
        Re-initialize network weights and optimizers for a fresh agent to train
        '''

        # Main network
        self.best_mean_reward = -np.inf
        self.ac = self.actor_critic(self.env.observation_space,
                                    self.env.action_space,
                                    device=self.device,
                                    ngpu=self.ngpu,
                                    **self.ac_kwargs)

        # Create Optimizers
        self.v_optimizer = optim.Adam(self.ac.v.parameters(), lr=self.vf_lr)
        self.pi_optimizer = optim.Adam(self.ac.pi.parameters(), lr=self.pi_lr)

        self.buffer = GAEBuffer(self.obs_dim, self.act_dim,
                                self.steps_per_epoch, self.device, self.gamma,
                                self.lam)

    def update(self):
        self.ac.train()
        data = self.buffer.get()
        obs_ = data['obs']
        act_ = data['act']
        ret_ = data['ret']
        adv_ = data['adv']
        logp_old_ = data['logp']

        for index in BatchSampler(
                SubsetRandomSampler(range(self.steps_per_epoch)),
                self.batch_size, False):
            obs = obs_[index]
            act = act_[index]
            ret = ret_[index]
            adv = adv_[index]
            logp_old = logp_old_[index]

            # ---------------------Recording the losses before the updates --------------------------------
            pi, logp = self.ac.pi(obs, act)
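            # importance ratio pi_new(a|s) / pi_old(a|s) computed in log space, and the
            # clipped surrogate objective -- evaluated here purely for logging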
            ratio = torch.exp(logp - logp_old)
            clipped_adv = torch.clamp(ratio, 1 - self.clip_ratio,
                                      1 + self.clip_ratio) * adv
            loss_pi = -torch.min(ratio * adv, clipped_adv).mean()
            v = self.ac.v(obs)
            loss_v = ((v - ret)**2).mean()

            self.logger.store(LossV=loss_v.item(), LossPi=loss_pi.item())
            # --------------------------------------------------------------------------------------------

            # Update Policy
            for i in range(self.train_pi_iters):
                # Policy loss
                self.pi_optimizer.zero_grad()
                pi, logp = self.ac.pi(obs, act)
                ratio = torch.exp(logp - logp_old)
                clipped_adv = torch.clamp(ratio, 1 - self.clip_ratio,
                                          1 + self.clip_ratio) * adv
                loss_pi = -torch.min(ratio * adv, clipped_adv).mean()
                # approximate KL(pi_old || pi_new); stop early if the policy has moved too far
                approx_kl = (logp_old - logp).mean().item()
                if approx_kl > 1.5 * self.target_kl:
                    print(f"Early stopping at step {i} due to reaching max kl")
                    break
                loss_pi.backward()
                self.pi_optimizer.step()

            # Update Value Function
            for _ in range(self.train_v_iters):
                self.v_optimizer.zero_grad()
                v = self.ac.v(obs)
                loss_v = ((v - ret)**2).mean()
                loss_v.backward()
                self.v_optimizer.step()

    def save_weights(self, best=False, fname=None):
        '''
        Save the PyTorch model weights of the actor and critic networks and their optimizers,
        as well as serialising the environment to env.json
        Args:
            best (bool): if True, save the weights as best.pth
            fname (string): if specified, save the weights as <fname>
        '''
        if fname is not None:
            _fname = fname
        elif best:
            _fname = "best.pth"
        else:
            _fname = "model_weights.pth"

        print('saving checkpoint...')
        checkpoint = {
            'v': self.ac.v.state_dict(),
            'pi': self.ac.pi.state_dict(),
            'v_optimizer': self.v_optimizer.state_dict(),
            'pi_optimizer': self.pi_optimizer.state_dict()
        }
        self.env.save(os.path.join(self.save_dir, "env.json"))
        torch.save(checkpoint, os.path.join(self.save_dir, _fname))
        print(f"checkpoint saved at {os.path.join(self.save_dir, _fname)}")

    def load_weights(self, best=True):
        '''
        Load the model weights from self.save_dir
        Args:
            best (bool): If True, load the weights file with the best mean episode reward
        '''
        if best:
            fname = "best.pth"
        else:
            fname = "model_weights.pth"
        checkpoint_path = os.path.join(self.save_dir, fname)
        if os.path.isfile(checkpoint_path):
            key = 'cuda' if torch.cuda.is_available() else 'cpu'
            checkpoint = torch.load(checkpoint_path, map_location=key)
            self.ac.v.load_state_dict(
                sanitise_state_dict(checkpoint['v'], self.ngpu > 1))
            self.ac.pi.load_state_dict(
                sanitise_state_dict(checkpoint['pi'], self.ngpu > 1))
            self.v_optimizer.load_state_dict(
                sanitise_state_dict(checkpoint['v_optimizer'], self.ngpu > 1))
            self.pi_optimizer.load_state_dict(
                sanitise_state_dict(checkpoint['pi_optimizer'], self.ngpu > 1))

            env_path = os.path.join(self.save_dir, "env.json")
            if os.path.isfile(env_path):
                self.env = self.env.load(env_path)
                print("Environment loaded")
            print('checkpoint loaded at {}'.format(checkpoint_path))
        else:
            raise OSError("Checkpoint file not found.")

    def learn_one_trial(self, timesteps, trial_num):
        ep_rets = []
        epochs = int((timesteps / self.steps_per_epoch) + 0.5)
        print(
            "Rounded off to {} epochs with {} steps per epoch, total {} timesteps"
            .format(epochs, self.steps_per_epoch,
                    epochs * self.steps_per_epoch))
        start_time = time.time()
        obs, ep_ret, ep_len = self.env.reset(), 0, 0
        ep_num = 0
        for epoch in tqdm(range(epochs)):
            for t in range(self.steps_per_epoch):
                # step the environment
                a, v, logp = self.ac.step(
                    torch.as_tensor(obs, dtype=torch.float32).to(self.device))
                next_obs, reward, done, _ = self.env.step(a)
                ep_ret += reward
                ep_len += 1

                # Add experience to buffer
                self.buffer.store(obs, a, reward, v, logp)

                obs = next_obs
                timeout = ep_len == self.max_ep_len
                terminal = done or timeout
                epoch_ended = t == self.steps_per_epoch - 1

                # End of trajectory/episode handling
                if terminal or epoch_ended:
                    if timeout or epoch_ended:
                        _, v, _ = self.ac.step(
                            torch.as_tensor(obs, dtype=torch.float32).to(
                                self.device))
                    else:
                        v = 0

                    ep_num += 1
                    self.logger.store(EpRet=ep_ret, EpLen=ep_len)
                    self.tensorboard_logger.add_scalar(
                        'episodic_return_train', ep_ret,
                        epoch * self.steps_per_epoch + (t + 1))
                    self.buffer.finish_path(v)
                    obs, ep_ret, ep_len = self.env.reset(), 0, 0
                    # Retrieve training reward
                    x, y = self.logger.load_results(["EpLen", "EpRet"])
                    if len(x) > 0:
                        # Mean training reward over the last 50 episodes
                        mean_reward = np.mean(y[-50:])

                        # New best model
                        if mean_reward > self.best_mean_reward:
                            # print("Num timesteps: {}".format(timestep))
                            print(
                                "Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}"
                                .format(self.best_mean_reward, mean_reward))

                            self.best_mean_reward = mean_reward
                            self.save_weights(fname=f"best_{trial_num}.pth")

                        if self.env.spec.reward_threshold is not None and self.best_mean_reward >= self.env.spec.reward_threshold:
                            print("Solved Environment, stopping iteration...")
                            return

            # Update the value function and the PPO policy
            self.update()
            self.logger.dump()
            if self.save_freq > 0 and epoch % self.save_freq == 0:
                self.save_weights(fname=f"latest_{trial_num}.pth")

    def learn(self, timesteps, num_trials=1):
        '''
        Function to learn using PPO.
        Args:
            timesteps (int): number of timesteps to train for
            num_trials (int): Number of times to train the agent
        '''
        self.env.training = True
        best_reward_trial = -np.inf
        for trial in range(num_trials):
            self.tensorboard_logger = SummaryWriter(
                log_dir=os.path.join(self.tensorboard_logdir, f'{trial+1}'))
            self.learn_one_trial(timesteps, trial + 1)

            if self.best_mean_reward > best_reward_trial:
                best_reward_trial = self.best_mean_reward
                self.save_weights(best=True)

            self.logger.reset()
            self.reinit_network()
            print()
            print(f"Trial {trial+1}/{num_trials} complete")

    def test(self, timesteps=None, render=False, record=False):
        '''
        Test the agent in the environment
        Args:
            render (bool): If true, render the image out for user to see in real time
            record (bool): If true, save the recording into a .gif file at the end of episode
            timesteps (int): number of timesteps to run the environment for. Default None will run to completion
        Return:
            Ep_Ret (int): Total reward from the episode
            Ep_Len (int): Total length of the episode in terms of timesteps
        '''
        self.env.training = False
        if render:
            self.env.render('human')
        obs, done, ep_ret, ep_len = self.env.reset(), False, 0, 0
        img = []
        if record:
            img.append(self.env.render('rgb_array'))

        if timesteps is not None:
            for i in range(timesteps):
                # Take stochastic action with policy network
                action, _, _ = self.ac.step(
                    torch.as_tensor(obs, dtype=torch.float32).to(self.device))
                obs, reward, done, _ = self.env.step(action)
                if record:
                    img.append(self.env.render('rgb_array'))
                else:
                    self.env.render()
                ep_ret += reward
                ep_len += 1
        else:
            while not (done or (ep_len == self.max_ep_len)):
                # Take stochastic action with policy network
                action, _, _ = self.ac.step(
                    torch.as_tensor(obs, dtype=torch.float32).to(self.device))
                obs, reward, done, _ = self.env.step(action)
                if record:
                    img.append(self.env.render('rgb_array'))
                else:
                    self.env.render()
                ep_ret += reward
                ep_len += 1

        self.env.training = True
        if record:
            imageio.mimsave(
                os.path.join(self.save_dir, "recording.gif"),
                [np.array(frame) for i, frame in enumerate(img) if i % 2 == 0],
                fps=29)

        return ep_ret, ep_len
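As with the other agents in this example, PPO is driven entirely through learn() and test(). The sketch below is hypothetical: env_fn, the directories and the hyperparameter values are assumptions, and the environment is expected to be the repository's wrapper (providing .training, .save/.load and .spec).

# Hypothetical usage sketch -- env_fn, paths and hyperparameters are placeholders.
agent = PPO(env_fn,                                # returns the repo's wrapped gym env
            save_dir='./ppo_checkpoints',
            tensorboard_logdir='./ppo_tb',
            ac_kwargs=dict(hidden_sizes=(256, 256)),
            steps_per_epoch=4000,
            batch_size=400,
            clip_ratio=0.2,
            target_kl=0.01)
agent.learn(timesteps=200000, num_trials=1)        # ~50 epochs of 4000 steps each
ep_ret, ep_len = agent.test(render=False)          # run the trained policy for one episode
print(f"Test return: {ep_ret:.1f} over {ep_len} steps")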