Example 1

import copy

import numpy as np
import torch

# POLOAgent, MLP, ReplayBuffer, and the traj rollout utilities are
# project-local modules assumed to be importable alongside this class.

class BCAgent(POLOAgent):
    """
    An agent extending POLO that behavior-clones the planner's predicted
    actions and uses the cloned policy as a prior for MPC.
    """
    def __init__(self, params):
        super(BCAgent, self).__init__(params)

        # Initialize policy network
        pol_params = self.params['p-bc']['pol_params']
        pol_params['input_size'] = self.N
        pol_params['output_size'] = self.M
        if 'final_activation' not in pol_params:
            pol_params['final_activation'] = torch.tanh

        self.pol = MLP(pol_params)

        # Create policy optimizer
        ppar = self.params['p-bc']['pol_optim']
        self.pol_optim = torch.optim.Adam(self.pol.parameters(),
                                          lr=ppar['lr'],
                                          weight_decay=ppar['reg'])

        # Use a replay buffer that will save planner actions
        self.pol_buf = ReplayBuffer(self.N, self.M,
                                    self.params['p-bc']['buf_size'])

        # Logging (store cum_rew, cum_emp_rew)
        self.hist['pols'] = np.zeros((self.T, 2))

        self.has_pol = True

        self.pol_cache = ()

    def get_action(self):
        """
        BCAgent generates a planned trajectory using the behavior-cloned policy
        and then optimizes it via MPC.
        """
        self.pol.eval()

        # Run a rollout using the policy starting from the current state
        infos = self.get_traj_info()
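        # eval_traj is assumed to return a tuple of rollout statistics:
        # infos[1] is the policy's action sequence, infos[3:5] hold the
        # cumulative and empirical returns logged below, and infos[0],
        # infos[2] are cached for later use.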

        self.hist['pols'][self.time] = infos[3:5]
        self.pol_cache = (infos[0], infos[2])

        self.prior_actions = infos[1]

        # Run MPC seeded with the policy's rollout actions as a prior
        action = super(BCAgent, self).get_action(prior=self.prior_actions)

        # Add final planning trajectory to BC buffer
        fin_states, fin_rews = self.cache[2], self.cache[3]
        fin_states = np.concatenate(([self.prev_obs], fin_states[1:]))
        pb_pct = self.params['p-bc']['pb_pct']
        pb_len = int(pb_pct * fin_states.shape[0])
        for t in range(pb_len):
            self.pol_buf.update(fin_states[t], fin_states[t + 1], fin_rews[t],
                                self.planned_actions[t], False)

        return action

    def do_updates(self):
        """
        Learn from the saved buffer of planned actions.
        """
        super(BCAgent, self).do_updates()

        if self.time % self.params['p-bc']['update_freq'] == 0:
            self.update_pol()

    def update_pol(self):
        """
        Update the policy via BC on the planner actions.
        """
        self.pol.train()

        params = self.params['p-bc']

        # Generate batches for training
        size = min(self.pol_buf.size, self.pol_buf.total_in)
        num_inds = params['batch_size'] * params['grad_steps']
        inds = np.random.randint(0, size, size=num_inds)

        states = self.pol_buf.buffer['s'][inds]
        acts = self.pol_buf.buffer['a'][inds]

        states = torch.tensor(states, dtype=self.dtype)
        actions = torch.tensor(acts, dtype=self.dtype)

        for i in range(params['grad_steps']):
            bi, ei = i * params['batch_size'], (i + 1) * params['batch_size']

            # Train based on L2 distance between actions and predictions
            preds = self.pol.forward(states[bi:ei])
            preds = torch.squeeze(preds, dim=-1)
            targets = torch.squeeze(actions[bi:ei], dim=-1)

            loss = torch.nn.functional.mse_loss(preds, targets)

            self.pol_optim.zero_grad()
            loss.backward()
            self.pol_optim.step()

    def get_traj_info(self):
        """
        Run the policy for a full trajectory and return details about the
        trajectory.
        """
        env_state = self.env.sim.get_state() if self.mujoco else None

        infos = traj.eval_traj(copy.deepcopy(self.env),
                               env_state,
                               self.prev_obs,
                               mujoco=self.mujoco,
                               perturb=self.perturb,
                               H=self.H,
                               gamma=self.gamma,
                               act_mode='deter',
                               pt=(self.pol, 0),
                               terminal=self.val_ens,
                               tvel=self.tvel)

        return infos

    def print_logs(self):
        """
        BC-specific logging information.
        """
        bi, ei = super(BCAgent, self).print_logs()

        self.print('BC metrics', mode='head')

        self.print('policy traj rew', self.hist['pols'][self.time - 1][0])
        self.print('policy traj emp rew', self.hist['pols'][self.time - 1][1])

        return bi, ei

    def test_policy(self):
        """
        Evaluate the behavior-cloned policy on its own (no MPC) over a full
        evaluation rollout.
        """
        env = copy.deepcopy(self.env)
        obs = env.reset()

        if self.tvel is not None:
            env.set_target_vel(self.tvel)
            obs = env._get_obs()

        env_state = env.sim.get_state() if self.mujoco else None
        infos = traj.eval_traj(env,
                               env_state,
                               obs,
                               mujoco=self.mujoco,
                               perturb=self.perturb,
                               H=self.eval_len,
                               gamma=1,
                               act_mode='deter',
                               pt=(self.pol, 0),
                               tvel=self.tvel)

        self.hist['pol_test'][self.time] = infos[3]
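
BCAgent reads its hyperparameters from a nested params['p-bc'] dictionary.
The sketch below is a hypothetical configuration assembled from the keys
referenced in __init__, get_action, and update_pol; the values and the
contents of 'pol_params' are illustrative assumptions, not taken from the
source.

# Hypothetical 'p-bc' configuration sketch; keys mirror those read by
# BCAgent, values are illustrative assumptions.
p_bc = {
    'pol_params': {
        # remaining MLP hyperparameters for the project's MLP class;
        # 'input_size', 'output_size', and (if absent) 'final_activation'
        # are filled in by BCAgent.__init__
    },
    'pol_optim': {
        'lr': 1e-3,        # Adam learning rate
        'reg': 1e-4,       # weight decay
    },
    'buf_size': 10000,     # capacity of the planner-action replay buffer
    'pb_pct': 0.5,         # fraction of each planned trajectory saved for BC
    'update_freq': 1,      # timesteps between calls to update_pol
    'grad_steps': 32,      # minibatches per BC update
    'batch_size': 64,      # samples per minibatch
}
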
Example 2

import numpy as np
import torch

# MPCAgent, Ensemble, and ReplayBuffer are project-local modules assumed to
# be importable alongside this class.

class POLOAgent(MPCAgent):
    """
    MPC-based agent that uses the Plan Online, Learn Offline (POLO) framework
    (Lowrey et al., 2018) for trajectory optimization.
    """
    def __init__(self, params):
        super(POLOAgent, self).__init__(params)
        self.H_backup = self.params['polo']['H_backup']

        # Create ensemble of value functions
        model_params = params['polo']['ens_params']['model_params']
        model_params['input_size'] = self.N
        model_params['output_size'] = 1

        params['polo']['ens_params']['dtype'] = self.dtype
        params['polo']['ens_params']['device'] = self.device

        self.val_ens = Ensemble(self.params['polo']['ens_params'])

        # Learn from replay buffer
        self.polo_buf = ReplayBuffer(self.N, self.M,
                                     self.params['polo']['buf_size'])

        # Value (from forward), value mean, value std
        self.hist['vals'] = np.zeros((self.T, 3))

    def get_action(self, prior=None):
        """
        POLO selects an action via MPC optimization with an optimistic
        terminal value function.
        """
        self.val_ens.eval()

        # Get value of current state
        s = torch.tensor(self.prev_obs, dtype=self.dtype)
        s = s.to(device=self.device)
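        # val_ens.forward is assumed to return a tuple whose first element
        # is the ensemble's aggregate value estimate for the state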
        current_val = self.val_ens.forward(s)[0]
        current_val = torch.squeeze(current_val, -1)
        current_val = current_val.detach().cpu().numpy()

        # Get prediction of every function in ensemble
        preds = self.val_ens.get_preds_np(self.prev_obs)

        # Log information from value function
        self.hist['vals'][self.time] = \
            np.array([current_val, np.mean(preds), np.std(preds)])

        # Run MPC to get action
        act = super(POLOAgent, self).get_action(terminal=self.val_ens,
                                                prior=prior)

        return act

    def action_taken(self, prev_obs, obs, rew, done, ifo):
        """
        Update buffer for value function learning.
        """
        self.polo_buf.update(prev_obs, obs, rew, done)

    def do_updates(self):
        """
        POLO learns a value function from its history of real interactions
        with the environment.
        """
        super(POLOAgent, self).do_updates()
        if self.time % self.params['polo']['update_freq'] == 0:
            self.val_ens.update_from_buf(self.polo_buf,
                                         self.params['polo']['grad_steps'],
                                         self.params['polo']['batch_size'],
                                         self.params['polo']['H_backup'],
                                         self.gamma)

    def print_logs(self):
        """
        POLO-specific logging information.
        """
        bi, ei = super(POLOAgent, self).print_logs()

        self.print('POLO metrics', mode='head')

        self.print('current state val', self.hist['vals'][self.time - 1][0])
        self.print('current state std', self.hist['vals'][self.time - 1][2])

        return bi, ei
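
POLOAgent similarly pulls its settings from params['polo']. The following is
a hypothetical sketch built from the keys referenced above; the values are
illustrative assumptions, not taken from the source.

# Hypothetical 'polo' configuration sketch; keys mirror those read by
# POLOAgent, values are illustrative assumptions.
polo = {
    'H_backup': 8,          # backup horizon passed to Ensemble.update_from_buf
    'ens_params': {
        'model_params': {
            # value-network hyperparameters for the project's Ensemble;
            # 'input_size' and 'output_size' are filled in by __init__
        },
        # 'dtype' and 'device' are also injected by __init__
    },
    'buf_size': 50000,      # capacity of the real-transition replay buffer
    'update_freq': 25,      # timesteps between value-function updates
    'grad_steps': 64,       # minibatches per update
    'batch_size': 32,       # samples per minibatch
}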