Example #1
    # Constructor of an ensemble class (the class statement and imports are
    # not part of this snippet): builds `ens_size` trainable MLPs, matching
    # frozen prior networks, and one Adam optimizer per member.
    def __init__(self, params):
        self.params = params
        self.kappa = self.params['kappa']

        self.dtype = self.params['dtype']
        self.device = self.params['device']

        self.models = []
        self.priors = []
        self.optims = []

        for i in range(self.params['ens_size']):
            model = MLP(self.params['model_params']).to(device=self.device)
            self.models.append(model)
            self.optims.append(
                torch.optim.Adam(model.parameters(),
                                 lr=self.params['lr'],
                                 weight_decay=self.params['reg']))

            # Fixed prior network: kept in eval() mode and never given to an
            # optimizer, so its randomly initialized weights stay frozen.
            prior = MLP(self.params['model_params']).to(device=self.device)
            prior.eval()
            self.priors.append(prior)
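
The constructor above stores a kappa scale, trainable MLPs, and frozen prior networks, which matches the randomized-prior-functions pattern for ensembles. A minimal sketch of how such an ensemble is typically queried; this helper is an assumption for illustration, not code from the original class:

import torch

def ensemble_forward(models, priors, kappa, x):
    # Each member's prediction is its trainable network's output plus the
    # frozen random prior's output scaled by kappa (no gradient through it).
    outs = []
    for model, prior in zip(models, priors):
        with torch.no_grad():
            prior_out = prior(x)
        outs.append(model(x) + kappa * prior_out)
    return torch.stack(outs, dim=0)  # (ens_size, batch, output_size)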
Example #2
class BCAgent(POLOAgent):
    """
    An agent extending POLO that uses behavior cloning on the planner's
    predicted actions as a prior for MPC.
    """
    def __init__(self, params):
        super(BCAgent, self).__init__(params)

        # Initialize policy network
        pol_params = self.params['p-bc']['pol_params']
        pol_params['input_size'] = self.N
        pol_params['output_size'] = self.M
        if 'final_activation' not in pol_params:
            pol_params['final_activation'] = torch.tanh

        self.pol = MLP(pol_params)

        # Create policy optimizer
        ppar = self.params['p-bc']['pol_optim']
        self.pol_optim = torch.optim.Adam(self.pol.parameters(),
                                          lr=ppar['lr'],
                                          weight_decay=ppar['reg'])

        # Use a replay buffer that will save planner actions
        self.pol_buf = ReplayBuffer(self.N, self.M,
                                    self.params['p-bc']['buf_size'])

        # Logging (store cum_rew, cum_emp_rew)
        self.hist['pols'] = np.zeros((self.T, 2))

        self.has_pol = True

        self.pol_cache = ()

    def get_action(self):
        """
        BCAgent generates a planned trajectory using the behavior-cloned policy
        and then optimizes it via MPC.
        """
        self.pol.eval()

        # Run a rollout using the policy starting from the current state
        infos = self.get_traj_info()
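        # As used below: infos[1] holds the rollout actions and infos[3:5] the
        # cumulative (rew, emp_rew) pair; infos[0] and infos[2] are cached for
        # later use (exact contents depend on traj.eval_traj).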

        self.hist['pols'][self.time] = infos[3:5]
        self.pol_cache = (infos[0], infos[2])

        self.prior_actions = infos[1]

        # Generate trajectory via MPC with the prior actions as a prior
        action = super(BCAgent, self).get_action(prior=self.prior_actions)

        # Add final planning trajectory to BC buffer
        fin_states, fin_rews = self.cache[2], self.cache[3]
        fin_states = np.concatenate(([self.prev_obs], fin_states[1:]))
        pb_pct = self.params['p-bc']['pb_pct']
        pb_len = int(pb_pct * fin_states.shape[0])
        for t in range(pb_len):
            self.pol_buf.update(fin_states[t], fin_states[t + 1], fin_rews[t],
                                self.planned_actions[t], False)

        return action

    def do_updates(self):
        """
        Learn from the saved buffer of planned actions.
        """
        super(BCAgent, self).do_updates()

        if self.time % self.params['p-bc']['update_freq'] == 0:
            self.update_pol()

    def update_pol(self):
        """
        Update the policy via BC on the planner actions.
        """
        self.pol.train()

        params = self.params['p-bc']

        # Generate batches for training
        size = min(self.pol_buf.size, self.pol_buf.total_in)
        num_inds = params['batch_size'] * params['grad_steps']
        inds = np.random.randint(0, size, size=num_inds)

        states = self.pol_buf.buffer['s'][inds]
        acts = self.pol_buf.buffer['a'][inds]

        states = torch.tensor(states, dtype=self.dtype)
        actions = torch.tensor(acts, dtype=self.dtype)

        for i in range(params['grad_steps']):
            bi, ei = i * params['batch_size'], (i + 1) * params['batch_size']

            # Train based on L2 distance between actions and predictions
            preds = self.pol.forward(states[bi:ei])
            preds = torch.squeeze(preds, dim=-1)
            targets = torch.squeeze(actions[bi:ei], dim=-1)

            loss = torch.nn.functional.mse_loss(preds, targets)

            self.pol_optim.zero_grad()
            loss.backward()
            self.pol_optim.step()

    def get_traj_info(self):
        """
        Run the policy for a full trajectory and return details about the
        trajectory.
        """
        env_state = self.env.sim.get_state() if self.mujoco else None

        infos = traj.eval_traj(copy.deepcopy(self.env),
                               env_state,
                               self.prev_obs,
                               mujoco=self.mujoco,
                               perturb=self.perturb,
                               H=self.H,
                               gamma=self.gamma,
                               act_mode='deter',
                               pt=(self.pol, 0),
                               terminal=self.val_ens,
                               tvel=self.tvel)

        return infos

    def print_logs(self):
        """
        BC-specific logging information.
        """
        bi, ei = super(BCAgent, self).print_logs()

        self.print('BC metrics', mode='head')

        self.print('policy traj rew', self.hist['pols'][self.time - 1][0])
        self.print('policy traj emp rew', self.hist['pols'][self.time - 1][1])

        return bi, ei

    def test_policy(self):
        """
        Evaluate the behavior-cloned policy on a fresh copy of the environment.
        """
        env = copy.deepcopy(self.env)
        obs = env.reset()

        if self.tvel is not None:
            env.set_target_vel(self.tvel)
            obs = env._get_obs()

        env_state = env.sim.get_state() if self.mujoco else None
        infos = traj.eval_traj(env,
                               env_state,
                               obs,
                               mujoco=self.mujoco,
                               perturb=self.perturb,
                               H=self.eval_len,
                               gamma=1,
                               act_mode='deter',
                               pt=(self.pol, 0),
                               tvel=self.tvel)

        self.hist['pol_test'][self.time] = infos[3]
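
The ReplayBuffer used by BCAgent is not defined in these examples. Based only on how it is called above (update(s, s', r, a, done), the 's' and 'a' fields, size, and total_in), a minimal compatible ring buffer might look like the following sketch; the 'sp', 'r', and 'd' field names are assumptions:

import numpy as np

class SimpleReplayBuffer:
    # Minimal stand-in with the interface BCAgent relies on.
    def __init__(self, N, M, size):
        self.size = size        # capacity
        self.total_in = 0       # total number of transitions ever added
        self.buffer = {'s': np.zeros((size, N)), 'sp': np.zeros((size, N)),
                       'r': np.zeros(size), 'a': np.zeros((size, M)),
                       'd': np.zeros(size, dtype=bool)}

    def update(self, s, sp, r, a, done):
        i = self.total_in % self.size   # overwrite the oldest entry when full
        self.buffer['s'][i] = s
        self.buffer['sp'][i] = sp
        self.buffer['r'][i] = r
        self.buffer['a'][i] = a
        self.buffer['d'][i] = done
        self.total_in += 1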
Example #3
class VPGAgent(Agent):
    """
    An agent running online policy gradient. VPGAgent itself uses REINFORCE,
    but it can be subclassed for other policy gradient algorithms.
    """
    def __init__(self, params):
        super(VPGAgent, self).__init__(params)
        self.H = self.params['pg']['H']
        self.lam = self.params['pg']['lam']

        # Initialize policy network
        pol_params = self.params['pg']['pol_params']
        pol_params['input_size'] = self.N
        pol_params['output_size'] = self.M
        if 'final_activation' not in pol_params:
            pol_params['final_activation'] = torch.tanh

        self.pol = MLP(pol_params)

        # Standard deviations are state-independent
        init_log_std = -0.8 * torch.ones(self.M)  # std = exp(-0.8) ~ 0.45
        self.log_std = torch.nn.Parameter(init_log_std, requires_grad=True)

        # Create policy optimizer
        ppar = self.params['pg']['pol_optim']
        self.pol_params = list(self.pol.parameters()) + [self.log_std]
        self.pol_optim = torch.optim.Adam(self.pol_params,
                                          lr=ppar['lr'],
                                          weight_decay=ppar['reg'])

        # Create value function and optimizer
        val_params = self.params['pg']['val_params']
        val_params['input_size'] = self.N
        val_params['output_size'] = 1

        self.val = MLP(val_params)

        vpar = self.params['pg']['val_optim']
        self.val_optim = torch.optim.Adam(self.val.parameters(),
                                          lr=vpar['lr'],
                                          weight_decay=vpar['reg'])

        # Logging
        self.hist['ent'] = np.zeros(self.T)

    def get_dist(self, s):
        """
        Create a PyTorch Normal distribution from the policy network for
        state s.
        """
        s = torch.tensor(s, dtype=self.dtype)
        mu = self.pol.forward(s)
        std = self.log_std.exp()

        return torch.distributions.Normal(mu, std)

    def get_ent(self):
        """
        Return the current entropy (multivariate Gaussian).
        """
        std = self.log_std.exp()
        tpe = 2 * np.pi * np.e
        return .5 * torch.log(tpe * torch.prod(std))

    def get_action(self):
        """
        Get an action by running the policy.
        """
        self.pol.eval()

        if self.params['pg']['run_deterministic']:
            x = torch.tensor(self.prev_obs, dtype=self.dtype)
            act = self.pol.forward(x).detach().cpu().numpy()
        else:
            act = sample_pol(self.pol, self.log_std, self.prev_obs)

        act = np.clip(act, self.params['env']['min_act'],
                      self.params['env']['max_act'])

        self.hist['ent'][self.time] = self.get_ent().detach().cpu().numpy()

        return act

    def do_updates(self):
        """
        Performs actor and critic updates.
        """
        if self.time % self.params['pg']['update_every'] == 0 or self.time == 1:
            plan_time = 0
            H, num_rollouts = self.H, self.params['pg']['num_rollouts']
            for i in range(self.params['pg']['num_iter']):
                # Sample rollouts using ground truth model
                check = time.time()
                rollouts = self.sample_rollouts(H, num_rollouts)
                plan_time += time.time() - check

                # Performs value updates alongside advantage calculation
                rews = self.update_pol(rollouts)

            # Time spent generating rollouts should be considered planning time
            self.hist['plan_time'][self.time - 1] += plan_time
            self.hist['update_time'][self.time - 1] -= plan_time

    def sample_rollouts(self, H, num_rollouts):
        """
        Use traj module to sample rollouts using the policy.
        """
        env_state = self.env.sim.get_state() if self.mujoco else None

        self.pol.eval()
        rollouts = traj.generate_trajectories(
            num_rollouts,
            self.env,
            env_state,
            self.prev_obs,
            mujoco=self.mujoco,
            perturb=self.perturb,
            H=H,
            gamma=self.gamma,
            act_mode='gauss',
            pt=(sample_pol, self.pol, self.log_std),
            terminal=None,
            tvel=self.tvel,
            num_cpu=self.params['pg']['num_cpu'])

        return rollouts

    def update_val(self, obs, targets):
        """
        Update value function with MSE loss.
        """
        preds = self.val.forward(obs)
        preds = torch.squeeze(preds, dim=-1)

        loss = torch.nn.functional.mse_loss(targets, preds)

        self.val_optim.zero_grad()
        loss.backward(retain_graph=True)
        self.val_optim.step()

        return loss.item()

    def calc_advs(self, obs, rews, update_vals=True):
        """
        Calculate advantages used to update the policy (and optionally the
        value function). Can use either rewards-to-go or GAE.
        """
        num_rollouts, H = obs.shape[:2]

        self.val.eval()

        if not self.params['pg']['use_gae']:
            # Calculate terminal values
            fin_obs = obs[:, -1]
            fin_vals = self.val.forward(fin_obs)
            fin_vals = torch.squeeze(fin_vals, dim=-1)

            # Calculate rewards-to-go
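            #   rtg_k = r_k + gamma * rtg_{k+1}, with the final step
            #   bootstrapped as rtg_{H-1} = r_{H-1} + gamma * V(fin_obs)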
            rtg = torch.zeros((num_rollouts, H))
            for k in reversed(range(H)):
                if k < H - 1:
                    rtg[:, k] += self.gamma * rtg[:, k + 1]
                else:
                    rtg[:, k] += self.gamma * fin_vals
                rtg[:, k] += rews[:, k]

            if update_vals:
                self.val.train()
                self.update_val(obs, rtg)

            # Center advantages at each time step for the policy gradient
            for k in range(H):
                rtg[:, k] -= torch.mean(rtg[:, k])

            return rtg

        # Generalized Advantage Estimation (GAE)
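        # obs[:, k] holds the state reached after action k, so below
        # prev_vals = V(s_k) and vals[:, k] = V(s_{k+1}); the TD residual is
        #   delta_k = r_k + gamma * V(s_{k+1}) - V(s_k)
        # and advantages follow the recursion A_k = delta_k + lam * gamma * A_{k+1}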
        prev_obs = torch.tensor(self.prev_obs, dtype=self.dtype)
        orig_val = self.val.forward(prev_obs)
        vals = torch.squeeze(self.val.forward(obs), dim=-1)

        deltas = torch.zeros(rews.shape)
        advs = torch.zeros((num_rollouts, H))

        lg = self.lam * self.gamma
        for k in reversed(range(H)):
            prev_vals = vals[:, k - 1] if k > 0 else orig_val
            deltas[:, k] = self.gamma * vals[:, k] + rews[:, k] - prev_vals

            if k == H - 1:
                advs[:, k] = deltas[:, k]
            else:
                advs[:, k] = lg * advs[:, k + 1] + deltas[:, k]

        advs = advs.detach()

        # Optionally, also update the value functions
        if update_vals:
            self.val.train()

            # It is reasonable to train on advs or deltas
            dvals = advs

            # Shift by one step so each target V(s_k) + A_k is paired with the
            # state s_k it was computed from
            fvals = torch.stack([orig_val for _ in range(vals.shape[0])],
                                dim=0)
            rets = torch.cat(
                [fvals + dvals[:, :1], vals[:, :-1] + dvals[:, 1:]], dim=-1)
            fobs = torch.unsqueeze(prev_obs, dim=0)
            fobs = torch.stack([fobs for _ in range(vals.shape[0])], dim=0)
            obs = torch.cat([fobs, obs[:, :-1]], dim=1)

            self.update_val(obs, rets)

        # Normalize advantages for policy gradient
        advs -= torch.mean(advs)
        advs /= 1e-3 + torch.std(advs)

        return advs

    def get_pol_loss(self, logprob, advs, orig_logprob=None):
        """
        For REINFORCE, the policy loss is the log-probabilities times the
        advantages. It is important that the log-probabilities carry the
        gradient so that we can backpropagate through them in the policy
        update.
        """
        return torch.mean(logprob * advs)

    def get_logprob(self, pol, log_std, obs, acts):
        """
        Get log probabilities for the actions, keeping the gradients.
        """
        num_rollouts, H = obs.shape[0:2]

        pol.train()

        dist = self.get_dist(obs)
        logprob = dist.log_prob(acts).sum(-1)

        return logprob

    def update_pol(self, rollouts, orig_logprob=None):
        """
        Update the policy on the on-policy rollouts.
        """
        H = rollouts[0][0].shape[0]

        self.pol.train()

        obs = np.zeros((len(rollouts), self.H, self.N))
        acts = np.zeros((len(rollouts), self.H, self.M))
        rews = torch.zeros((len(rollouts), self.H))
        for i in range(len(rollouts)):
            for k in range(self.H):
                obs[i, k] = rollouts[i][0][k]
                acts[i, k] = rollouts[i][1][k]
                rews[i, k] = rollouts[i][2][k]

        obs = torch.tensor(obs, dtype=self.dtype)
        acts = torch.tensor(acts, dtype=self.dtype)

        # Perform updates for multiple steps on the value function
        if self.params['pg']['use_gae']:
            for _ in range(self.params['pg']['val_steps']):
                advs = self.calc_advs(obs, rews, update_vals=True)
        else:
            advs = self.calc_advs(obs, rews, update_vals=False)

        # Perform updates for multiple epochs on the policy
        bsize = self.params['pg']['batch_size']
        for _ in range(self.params['pg']['pol_steps']):
            inds = np.random.permutation(len(rollouts))

            binds = inds[:bsize]
            bobs, bacts = obs[binds], acts[binds]
            brews, badvs = rews[binds], advs[binds]

            if orig_logprob is not None:
                bprobs = orig_logprob[binds]
            else:
                bprobs = None

            # Get a logprob that has gradients
            logprob = self.get_logprob(self.pol, self.log_std, bobs, bacts)
            if not self.continue_updates(logprob, bprobs):
                break

            # Negate the objective so that minimizing J performs gradient
            # ascent on the policy return
            J = -self.get_pol_loss(logprob, badvs, orig_logprob=bprobs)

            # Apply entropy bonus
            ent_coef = self.params['pg']['pol_optim']['ent_temp']
            if ent_coef != 0:
                J -= ent_coef * self.get_ent()

            self.pol_optim.zero_grad()
            J.backward()
            # Clip after backward() so the freshly computed gradients are the
            # ones being rescaled
            torch.nn.utils.clip_grad_norm_(self.pol.parameters(),
                                           self.params['pg']['grad_clip'])
            self.pol_optim.step()

            # Clamp stds to be within set bounds
            log_min = np.log(self.params['pg']['min_std'])
            log_min = torch.tensor(log_min, dtype=self.dtype)
            log_max = np.log(self.params['pg']['max_std'])
            log_max = torch.tensor(log_max, dtype=self.dtype)
            self.log_std.data = torch.clamp(self.log_std.data, log_min,
                                            log_max)

        return rews

    def continue_updates(self, logprob, orig_logprob=None):
        """
        Hook deciding whether to continue policy updates; always True here,
        but subclasses may override it.
        """
        return True

    def print_logs(self):
        """
        Policy gradient-specific logging information.
        """
        bi, ei = super(VPGAgent, self).print_logs()

        self.print('policy gradient metrics', mode='head')

        self.print('entropy avg', np.mean(self.hist['ent'][bi:ei]))
        self.print('sigma avg',
                   np.mean(torch.exp(self.log_std).detach().cpu().numpy()))

        return bi, ei
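
The sample_pol helper referenced in get_action and sample_rollouts is not shown in these examples. A minimal sketch, assuming it draws a single Gaussian action from the policy mean and the shared state-independent log-std:

import torch

def sample_pol(pol, log_std, obs, dtype=torch.float32):
    # Hypothetical helper: sample an action from N(pol(obs), exp(log_std)).
    obs = torch.tensor(obs, dtype=dtype)
    with torch.no_grad():
        mu = pol.forward(obs)
        act = torch.distributions.Normal(mu, log_std.exp()).sample()
    return act.cpu().numpy()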