Example #1
    def cost_function(self, x, it=0):
        with torch.no_grad():
            plans, obs = self.create_particles(x, self._observation)
            returns = self.get_plan_values(obs,
                                           plans).view(self.num_rollouts, -1)
            weighted_returns = self.get_weighted_returns(returns)
            costs = -ptu.get_numpy(weighted_returns)

        if self._need_to_update_diagnostics:
            self.diagnostics.update(
                create_stats_ordered_dict(
                    'Iteration %d Returns' % it,
                    ptu.get_numpy(weighted_returns),
                    always_show_all_stats=True,
                ))

            self.diagnostics.update(
                create_stats_ordered_dict(
                    'Iteration %d Particle Stds' % it,
                    np.std(ptu.get_numpy(returns), axis=-1),
                    always_show_all_stats=True,
                ))

            variance = weighted_returns.var()
            particle_variance = returns.var(dim=-1)
            self.diagnostics['Return Leftover Variance'] = \
                ptu.get_numpy(variance - particle_variance.mean()).mean()

        return costs
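Every example on this page funnels an array or list through create_stats_ordered_dict and merges the result into a diagnostics dict. As a rough mental model only, not the library's exact implementation, the rlkit-style helper behaves like the sketch below: it returns an OrderedDict of summary statistics keyed by the given name.

from collections import OrderedDict

import numpy as np


def stats_ordered_dict_sketch(name, data, always_show_all_stats=True):
    """Hypothetical stand-in for an rlkit-style create_stats_ordered_dict."""
    data = np.asarray(data)
    stats = OrderedDict([
        ('%s Mean' % name, float(np.mean(data))),
        ('%s Std' % name, float(np.std(data))),
    ])
    if always_show_all_stats:
        stats['%s Max' % name] = float(np.max(data))
        stats['%s Min' % name] = float(np.min(data))
    return stats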
Example #2
 def get_diagnostics(self):
     stats = OrderedDict()
     stats.update(
         create_stats_ordered_dict(
             'mean',
             ptu.get_numpy(self.mean),
             # exclude_max_min=True,
         ))
     stats.update(
         create_stats_ordered_dict(
             'std',
             ptu.get_numpy(self.distribution.stddev),
         ))
     return stats
Example #3
 def get_diagnostics(self):
     stats = OrderedDict()
     stats.update(
         create_stats_ordered_dict(
             'mean',
             ptu.get_numpy(self.mean),
         ))
     stats.update(
         create_stats_ordered_dict('normal/std',
                                   ptu.get_numpy(self.normal_std)))
     stats.update(
         create_stats_ordered_dict(
             'normal/log_std',
             ptu.get_numpy(torch.log(self.normal_std)),
         ))
     return stats
Example #4
    def predict_transition(self, obs, actions, infos):
        if self.sampling_mode == 'ts':
            preds = self._predict_transition_ts(obs, actions, infos)
        elif self.sampling_mode == 'uniform':
            preds = self._predict_transition_uniform(obs, actions, infos)
        else:
            raise NotImplementedError('MPC sampling_mode not recognized')

        next_obs = obs + preds[:, 2:]
        rewards = preds[:, 0]
        dones = preds[:, 1] > 0.5
        if self.reward_func is not None:
            given_rewards = self.reward_func(obs,
                                             actions,
                                             next_obs,
                                             num_timesteps=self.num_timesteps)
            self.diagnostics.update(
                create_stats_ordered_dict(
                    'Reward Squared Error',
                    ptu.get_numpy((given_rewards - rewards)**2),
                    always_show_all_stats=True,
                ))
            rewards = given_rewards

        return next_obs, rewards, dones
Example #5
    def get_plan_values_batch(self, obs, plans):
        """
        Get corresponding values of the plans (higher corresponds to better plans). Classes
        that don't want to plan over actions or use trajectory sampling can reimplement
        convert_plans_to_actions (& convert_plan_to_action) and/or predict_transition.
        plans is input as as torch (horizon_length, num_particles (total), plan_dim).
        We maintain trajectory infos as torch (n_part, info_dim (ex. obs_dim)).
        """

        if self.use_gt_model:
            return self.get_plan_values_batch_gt(obs, plans)

        # *total* number of particles, NOT num_particles
        n_part = plans.shape[1]

        discount = 1
        returns, dones, infos = ptu.zeros(n_part), ptu.zeros(n_part), dict()

        # The effective planning horizon is self.horizon * self.repeat_length
        for t in range(self.horizon):
            for k in range(self.repeat_length):
                cur_actions = self.convert_plans_to_actions(obs, plans[t])
                obs, cur_rewards, cur_dones = self.predict_transition(
                    obs, cur_actions, infos)
                returns += discount * (1 - dones) * cur_rewards
                discount *= self.discount
                if self.predict_terminal:
                    dones = torch.max(dones, cur_dones.float())

        self.diagnostics.update(
            create_stats_ordered_dict(
                'MPC Termination',
                ptu.get_numpy(dones),
            ))

        if self.value_func is not None:
            terminal_values = self.value_func(
                obs, **self.value_func_kwargs).view(-1)
            returns += discount * (1 - dones) * terminal_values

            self.diagnostics.update(
                create_stats_ordered_dict(
                    'MPC Terminal Values',
                    ptu.get_numpy(terminal_values),
                ))

        return returns
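The loop in Example #5 accumulates discounted rewards while masking out particles that have already terminated. A stripped-down version of that accumulation, with a hypothetical step_fn standing in for convert_plans_to_actions plus predict_transition and without the predict_terminal flag, makes the discounting and done-masking explicit:

import torch


def accumulate_returns_sketch(step_fn, obs, plans, horizon, repeat_length,
                              discount):
    """Sketch; step_fn(obs, plan) -> (next_obs, rewards, dones) is assumed."""
    n_part = plans.shape[1]
    returns = torch.zeros(n_part)
    dones = torch.zeros(n_part)
    disc = 1.0
    for t in range(horizon):
        for _ in range(repeat_length):
            obs, rewards, cur_dones = step_fn(obs, plans[t])
            # a particle stops contributing once it has terminated
            returns += disc * (1 - dones) * rewards
            disc *= discount
            dones = torch.max(dones, cur_dones.float())
    return returns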
Example #6
 def get_diagnostics(self):
     path_lens = [len(path['actions']) for path in self._epoch_paths]
     stats = OrderedDict([
         ('num steps total', self._num_steps_total),
         ('num paths total', self._num_paths_total),
     ])
     stats.update(
         create_stats_ordered_dict(
             "path length",
             path_lens,
             always_show_all_stats=True,
         ))
     return stats
Example #7
    def reward_postprocessing(self,
                              rewards,
                              reward_kwargs=None,
                              *args,
                              **kwargs):
        if self.disagreement_threshold is None:
            return super().reward_postprocessing(rewards)
        rewards, diagnostics = super().reward_postprocessing(rewards)

        disagreements = reward_kwargs['disagreements']
        violated = disagreements > self.disagreement_threshold
        rewards[violated] = self.reward_bounds[0]

        if self._need_to_update_eval_statistics:
            diagnostics.update(
                create_stats_ordered_dict(
                    'Model Disagreement',
                    disagreements,
                ))
            diagnostics['Pct of Timesteps over Disagreement Cutoff'] = np.mean(
                violated)

        return rewards, diagnostics
Example #8
    def train_from_paths(self, paths):
        """
        Path preprocessing; have to copy so we don't modify when paths are used elsewhere
        """

        paths = copy.deepcopy(paths)
        for path in paths:
            # Other places like to have an extra dimension so that all arrays are 2D
            path['terminals'] = np.squeeze(path['terminals'], axis=-1)
            path['rewards'] = np.squeeze(path['rewards'], axis=-1)

            # Reward normalization; divide by std of reward in replay buffer
            path['rewards'] = np.clip(
                path['rewards'] / (self._reward_std + 1e-3), -10, 10)

        obs, actions = [], []
        for path in paths:
            obs.append(path['observations'])
            actions.append(path['actions'])
        obs = np.concatenate(obs, axis=0)
        actions = np.concatenate(actions, axis=0)

        obs_tensor, act_tensor = ptu.from_numpy(obs), ptu.from_numpy(actions)
        """
        Policy training loop
        """

        old_policy = copy.deepcopy(self.policy)
        with torch.no_grad():
            log_probs_old = old_policy.get_log_probs(
                obs_tensor, act_tensor).squeeze(dim=-1)

        rem_value_epochs = self.num_epochs
        for epoch in range(self.num_policy_epochs):
            """
            Recompute advantages at the beginning of each epoch. This allows for advantages
                to utilize the latest value function.
            Note: while this is not present in most implementations, it is recommended
                  by Andrychowicz et al. 2020.
            """

            path_functions.calculate_baselines(paths, self.value_func)
            path_functions.calculate_returns(paths, self.discount)
            path_functions.calculate_advantages(
                paths,
                self.discount,
                self.gae_lambda,
                self.normalize_advantages,
            )

            advantages, returns, baselines = [], [], []
            for path in paths:
                advantages = np.append(advantages, path['advantages'])
                returns = np.append(returns, path['returns'])

            if epoch == 0 and self._need_to_update_eval_statistics:
                with torch.no_grad():
                    values = torch.squeeze(self.value_func(obs_tensor), dim=-1)
                    values_np = ptu.get_numpy(values)
                first_val_loss = ((returns - values_np)**2).mean()

            old_params = self.policy.get_param_values()

            num_policy_steps = len(advantages) // self.policy_batch_size
            for _ in range(num_policy_steps):
                if num_policy_steps == 1:
                    batch = dict(
                        observations=obs,
                        actions=actions,
                        advantages=advantages,
                    )
                else:
                    batch = ppp.sample_batch(
                        self.policy_batch_size,
                        observations=obs,
                        actions=actions,
                        advantages=advantages,
                    )
                policy_loss, kl = self.train_policy(batch, old_policy)

            with torch.no_grad():
                log_probs = self.policy.get_log_probs(
                    obs_tensor, act_tensor).squeeze(dim=-1)
            kl = (log_probs_old - log_probs).mean()

            if (self.target_kl is not None
                    and kl > 1.5 * self.target_kl) or (kl != kl):
                if epoch > 0 or kl != kl:  # nan check
                    self.policy.set_param_values(old_params)
                break

            num_value_steps = len(advantages) // self.value_batch_size
            for i in range(num_value_steps):
                batch = ppp.sample_batch(
                    self.value_batch_size,
                    observations=obs,
                    targets=returns,
                )
                value_loss = self.train_value(batch)
            rem_value_epochs -= 1

        # Ensure the value function is always updated for the maximum number
        # of epochs, regardless of if the policy wants to terminate early.
        for _ in range(rem_value_epochs):
            num_value_steps = len(advantages) // self.value_batch_size
            for i in range(num_value_steps):
                batch = ppp.sample_batch(
                    self.value_batch_size,
                    observations=obs,
                    targets=returns,
                )
                value_loss = self.train_value(batch)

        if self._need_to_update_eval_statistics:
            with torch.no_grad():
                _, _, _, log_pi, *_ = self.policy(obs_tensor,
                                                  return_log_prob=True)
                values = torch.squeeze(self.value_func(obs_tensor), dim=-1)
                values_np = ptu.get_numpy(values)

            errors = returns - values_np
            explained_variance = 1 - (np.var(errors) / np.var(returns))
            value_loss = errors**2

            self.eval_statistics['Num Epochs'] = epoch + 1

            self.eval_statistics['Policy Loss'] = ptu.get_numpy(
                policy_loss).mean()
            self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl).mean()
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantages',
                    advantages,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Returns',
                    returns,
                ))

            self.eval_statistics['Value Loss'] = value_loss.mean()
            self.eval_statistics['First Value Loss'] = first_val_loss
            self.eval_statistics[
                'Value Explained Variance'] = explained_variance
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Values',
                    ptu.get_numpy(values),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Value Squared Errors',
                    value_loss,
                ))
Example #9
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        noise = ptu.randn(next_actions.shape) * self.target_policy_noise
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target)**2
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = -q_output.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = -q_output.mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 1',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 2',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
        self._n_train_steps_total += 1
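ptu.soft_update_from_to in Example #9 performs a Polyak (soft) target-network update. A minimal equivalent, assuming plain torch.nn.Module source and target networks with matching parameter order, is:

import torch


def soft_update_sketch(source, target, tau):
    """target <- tau * source + (1 - tau) * target, parameter-wise (sketch)."""
    with torch.no_grad():
        for src_param, tgt_param in zip(source.parameters(),
                                        target.parameters()):
            tgt_param.data.mul_(1.0 - tau)
            tgt_param.data.add_(tau * src_param.data)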
Example #10
    def train_from_torch(self, batch):
        # We only use the original batch to get the batch size for policy training
        """
        Generate synthetic data using dynamics model
        """
        if self._n_train_steps_total % self.rollout_generation_freq == 0:
            rollout_len = self.rollout_len_func(self._n_train_steps_total)
            total_samples = self.rollout_generation_freq * self.num_model_rollouts

            num_samples, generated_rewards, terminated = 0, np.array([]), []
            while num_samples < total_samples:
                batch_samples = min(self.rollout_batch_size,
                                    total_samples - num_samples)
                real_batch = self.replay_buffer.random_batch(batch_samples)
                start_states = real_batch['observations']

                with torch.no_grad():
                    paths = self.sample_paths(start_states, rollout_len)

                for path in paths:
                    self.generated_data_buffer.add_path(path)
                    num_samples += len(path['observations'])
                    generated_rewards = np.concatenate(
                        [generated_rewards, path['rewards'][:, 0]])
                    terminated.append(path['terminals'][-1, 0] > 0.5)

                if num_samples >= total_samples:
                    break

            gt.stamp('generating rollouts', unique=False)
        """
        Update policy on both real and generated data
        """

        batch_size = batch['observations'].shape[0]
        n_real_data = int(self.real_data_pct * batch_size)
        n_generated_data = batch_size - n_real_data

        for _ in range(self.num_policy_updates):
            batch = self.replay_buffer.random_batch(n_real_data)
            generated_batch = self.generated_data_buffer.random_batch(
                n_generated_data)

            for k in ('rewards', 'terminals', 'observations', 'actions',
                      'next_observations'):
                batch[k] = np.concatenate((batch[k], generated_batch[k]),
                                          axis=0)
                batch[k] = ptu.from_numpy(batch[k])

            self.policy_trainer.train_from_torch(batch)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics and self._n_train_steps_total % self.rollout_generation_freq == 0:
            self._need_to_update_eval_statistics = False

            self.eval_statistics['MBPO Rollout Length'] = rollout_len
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'MBPO Reward Predictions',
                    generated_rewards,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'MBPO Rollout Terminations',
                    np.array(terminated).astype(float),
                ))

        self._n_train_steps_total += 1
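rollout_len_func in Example #10 (and Example #15) is supplied elsewhere. MBPO-style setups typically ramp the model rollout length linearly with the number of training steps; a hypothetical schedule, with made-up default breakpoints, might be:

def linear_rollout_len_sketch(train_steps,
                              min_len=1,
                              max_len=15,
                              start_step=20_000,
                              end_step=100_000):
    """Linearly interpolate the rollout length between start_step and end_step."""
    if train_steps <= start_step:
        return min_len
    if train_steps >= end_step:
        return max_len
    frac = (train_steps - start_step) / (end_step - start_step)
    return int(min_len + frac * (max_len - min_len))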
Example #11
    def sample_paths(self, start_states, rollout_len):
        if self.sampling_mode == 'uniform':
            # Sample uniformly from a model of the ensemble (original MBPO; Janner et al. 2019)
            paths = mrf.policy(
                self.dynamics_model,
                self.policy_trainer.policy,
                start_states,
                max_path_length=rollout_len,
            )

        elif self.sampling_mode == 'mean_disagreement':
            # Sample with penalty for disagreement of the mean (MOReL; Kidambi et al. 2020)
            paths, disagreements = mrf.policy_with_disagreement(
                self.dynamics_model,
                self.policy_trainer.policy,
                start_states,
                max_path_length=rollout_len,
                disagreement_type='mean',
            )
            disagreements = ptu.get_numpy(disagreements)

            threshold = self.sampling_kwargs['threshold']
            penalty = self.sampling_kwargs['penalty']
            total_penalized, total_transitions = 0, 0
            for i, path in enumerate(paths):
                mask = np.zeros(len(path['rewards']))
                disagreement_values = disagreements[i]
                for t in range(len(path['rewards'])):
                    cur_mask = disagreement_values[t] > threshold
                    if t == 0:
                        mask[t] = cur_mask
                    elif cur_mask or mask[t - 1] > 0.5:
                        mask[t] = 1.
                    else:
                        mask[t] = 0.
                mask = mask.reshape(len(mask), 1)
                path['rewards'] = (1 - mask) * path['rewards'] - mask * penalty
                total_penalized += mask.sum()
                total_transitions += len(path['rewards'])

            self.eval_statistics[
                'Percent of Transitions Penalized'] = total_penalized / total_transitions
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Disagreement Values',
                    disagreements,
                ))

        elif self.sampling_mode == 'var_disagreement':
            # Sample with penalty for disagreement of the variance (MOPO; Yu et al. 2020)
            paths, disagreements = mrf.policy_with_disagreement(
                self.dynamics_model,
                self.policy_trainer.policy,
                start_states,
                max_path_length=rollout_len,
                disagreement_type='var',
            )
            disagreements = ptu.get_numpy(disagreements)

            reward_penalty = self.sampling_kwargs['reward_penalty']
            for i, path in enumerate(paths):
                path_disagreements = disagreements[
                    i, :len(path['rewards'])].reshape(*path['rewards'].shape)
                path['rewards'] -= reward_penalty * path_disagreements

            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Disagreement Values',
                    disagreements,
                ))

        else:
            raise NotImplementedError

        return paths
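The three branches of sample_paths in Example #11 are selected by sampling_mode and parameterized through sampling_kwargs, both set outside this snippet. The key names below ('threshold', 'penalty', 'reward_penalty') come from the code above; the values and the dict layout are only illustrative assumptions:

# Hypothetical trainer kwargs for the three sampling modes
mbpo_config = dict(sampling_mode='uniform', sampling_kwargs=dict())
morel_config = dict(sampling_mode='mean_disagreement',
                    sampling_kwargs=dict(threshold=0.1, penalty=50.0))
mopo_config = dict(sampling_mode='var_disagreement',
                   sampling_kwargs=dict(reward_penalty=1.0))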
Example #12
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy operations.
        """
        if self.policy_pre_activation_weight > 0:
            policy_actions, pre_tanh_value = self.policy(
                obs,
                return_preactivations=True,
            )
            pre_activation_policy_loss = ((pre_tanh_value**2).sum(
                dim=1).mean())
            q_output = self.qf(obs, policy_actions)
            raw_policy_loss = -q_output.mean()
            policy_loss = (
                raw_policy_loss +
                pre_activation_policy_loss * self.policy_pre_activation_weight)
        else:
            policy_actions = self.policy(obs)
            q_output = self.qf(obs, policy_actions)
            raw_policy_loss = policy_loss = -q_output.mean()
        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        # speed up computation by not backpropping these gradients
        next_actions = next_actions.detach()
        target_q_values = self.target_qf(
            next_obs,
            next_actions,
        )
        q_target = rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()
        q_target = torch.clamp(q_target, self.min_q_value, self.max_q_value)
        q_pred = self.qf(obs, actions)
        bellman_errors = (q_pred - q_target)**2
        raw_qf_loss = self.qf_criterion(q_pred, q_target)

        if self.qf_weight_decay > 0:
            reg_loss = self.qf_weight_decay * sum(
                torch.sum(param**2)
                for param in self.qf.regularizable_parameters())
            qf_loss = raw_qf_loss + reg_loss
        else:
            qf_loss = raw_qf_loss
        """
        Update Networks
        """

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        self._update_target_networks()
        """
        Save some statistics for eval using just one batch.
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics['Raw Policy Loss'] = np.mean(
                ptu.get_numpy(raw_policy_loss))
            self.eval_statistics['Preactivation Policy Loss'] = (
                self.eval_statistics['Policy Loss'] -
                self.eval_statistics['Raw Policy Loss'])
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Predictions',
                    ptu.get_numpy(q_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors',
                    ptu.get_numpy(bellman_errors),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
        self._n_train_steps_total += 1
Example #13
    def train_from_torch(self, batch):
        obs = batch['observations']
        next_obs = batch['next_observations']
        actions = batch['actions']
        rewards = batch['rewards']
        terminals = batch.get('terminals', ptu.zeros(rewards.shape[0], 1))

        """
        Policy and Alpha Loss
        """
        _, policy_mean, policy_logstd, *_ = self.policy(obs)
        dist = TanhNormal(policy_mean, policy_logstd.exp())
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        log_pi = log_pi.sum(dim=-1, keepdim=True)
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
        policy_loss = (alpha * log_pi - q_new_actions).mean()

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        _, next_policy_mean, next_policy_logstd, *_ = self.policy(next_obs)
        next_dist = TanhNormal(next_policy_mean, next_policy_logstd.exp())
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        new_log_pi = new_log_pi.sum(dim=-1, keepdim=True)
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        future_values = (1. - terminals) * self.discount * target_q_values
        q_target = self.reward_scale * rewards + future_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        if self.use_automatic_entropy_tuning:
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self._n_train_steps_total += 1

        self.try_update_target_networks()

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False

            policy_loss = (log_pi - q_new_actions).mean()
            policy_avg_std = torch.exp(policy_logstd).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_logstd),
            ))
            self.eval_statistics['Policy std'] = np.mean(ptu.get_numpy(policy_avg_std))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

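The automatic entropy tuning branch in Example #13 assumes log_alpha, target_entropy, and alpha_optimizer were created in the trainer's constructor, which is not shown. A typical SAC-style initialization, sketched with hypothetical defaults, is:

import numpy as np
import torch
import torch.optim as optim


def init_entropy_tuning_sketch(action_dim, lr=3e-4, device='cpu'):
    """Return (log_alpha, target_entropy, alpha_optimizer) for SAC-style tuning."""
    # common heuristic: target entropy = -|A|
    target_entropy = -float(np.prod(action_dim))
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    alpha_optimizer = optim.Adam([log_alpha], lr=lr)
    return log_alpha, target_entropy, alpha_optimizer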
Example #14
    def train_from_buffer(self, reward_kwargs=None):
        """
        Compute intrinsic reward: approximate lower bound to I(s'; z | s)
        """

        if self.relabel_rewards:

            rewards, (logp, logp_altz, denom), reward_diagnostics = \
                self.calculate_intrinsic_rewards(
                    self._obs[:self._cur_replay_size],
                    self._next_obs[:self._cur_replay_size],
                    self._latents[:self._cur_replay_size],
                    reward_kwargs=reward_kwargs,
                )
            orig_rewards = rewards.copy()
            rewards, postproc_dict = self.reward_postprocessing(
                rewards, reward_kwargs=reward_kwargs)
            reward_diagnostics.update(postproc_dict)
            self._rewards[:self._cur_replay_size] = np.expand_dims(
                rewards, axis=-1)

            gt.stamp('intrinsic reward calculation', unique=False)
        """
        Train policy
        """

        state_latents = np.concatenate([self._obs, self._latents],
                                       axis=-1)[:self._cur_replay_size]
        next_state_latents = np.concatenate(
            [self._true_next_obs, self._latents],
            axis=-1)[:self._cur_replay_size]

        for _ in range(self.num_policy_updates):
            batch = ppp.sample_batch(
                self.policy_batch_size,
                observations=state_latents,
                next_observations=next_state_latents,
                actions=self._actions[:self._cur_replay_size],
                rewards=self._rewards[:self._cur_replay_size],
            )
            batch = ptu.np_to_pytorch_batch(batch)
            self.policy_trainer.train_from_torch(batch)

        gt.stamp('policy training', unique=False)
        """
        Diagnostics
        """

        if self._need_to_update_eval_statistics:
            self.eval_statistics.update(self.policy_trainer.eval_statistics)

            if self.relabel_rewards:
                self.eval_statistics.update(reward_diagnostics)

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Discriminator Log Pis',
                        logp,
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Discriminator Alt Log Pis',
                        logp_altz,
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Reward Denominator',
                        denom,
                    ))

                # Adjustment so intrinsic rewards are over last epoch
                if self._ptr < self._epoch_size:
                    if self._ptr == 0:
                        inds = np.r_[len(rewards) -
                                     self._epoch_size:len(rewards)]
                    else:
                        inds = np.r_[0:self._ptr,
                                     len(rewards) - self._ptr:len(rewards)]
                else:
                    inds = np.r_[self._ptr - self._epoch_size:self._ptr]

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Rewards (Original)',
                        orig_rewards[inds],
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Rewards (Processed)',
                        rewards[inds],
                    ))

        self._n_train_steps_total += 1
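calculate_intrinsic_rewards in Example #14 is not shown; its docstring describes an approximate lower bound to I(s'; z | s), which is typically estimated DADS-style by contrasting the discriminator's log-probability under the rollout's latent against log-probabilities under alternative latents. The sketch below shows one such estimate; log_prob_fn and the exact decomposition of the returned tuple are assumptions, not the repo's implementation:

import numpy as np


def intrinsic_reward_sketch(log_prob_fn, obs, next_obs, latents, alt_latents):
    """r(s, s', z) = log q(s'|s,z) - log((1/L) * sum_i q(s'|s,z_i)) (sketch)."""
    logp = log_prob_fn(obs, next_obs, latents)                   # shape (N,)
    logp_altz = np.stack([
        log_prob_fn(obs, next_obs, alt) for alt in alt_latents   # L rows of (N,)
    ])
    # log-mean-exp of the alternative-latent ratios gives the denominator term
    denom = np.log(np.mean(np.exp(logp_altz - logp[None]), axis=0))
    rewards = -denom  # equivalent to logp minus log-mean-exp of logp_altz
    return rewards, (logp, logp_altz, denom)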
Example #15
    def train_from_torch(self, batch):

        self._train_calls += 1
        if self._train_calls % self.train_every > 0:
            return

        rollout_len = self.rollout_len_func(self._n_train_steps_total)
        num_model_rollouts = max(self.num_model_samples // rollout_len, 1)
        self.eval_statistics['Rollout Length'] = rollout_len

        real_batch = self.replay_buffer.random_batch(num_model_rollouts)
        start_states = real_batch['observations']
        latents = self.generate_latents(start_states)

        observations = np.zeros((self.num_model_samples, self.obs_dim))
        next_observations = np.zeros((self.num_model_samples, self.obs_dim))
        actions = np.zeros((self.num_model_samples, self.action_dim))
        unfolded_latents = np.zeros((self.num_model_samples, self.latent_dim))
        disagreements = np.zeros(self.num_model_samples)

        num_samples, b_ind, num_traj = 0, 0, 0
        while num_samples < self.num_model_samples:
            e_ind = b_ind + 4192 // rollout_len
            with torch.no_grad():
                paths, path_disagreements = self.generate_paths(
                    dynamics_model=self.dynamics_model,
                    control_policy=self.control_policy,
                    start_states=start_states[b_ind:e_ind],
                    latents=ptu.from_numpy(latents[b_ind:e_ind]),
                    rollout_len=rollout_len,
                )

            b_ind = e_ind

            path_disagreements = ptu.get_numpy(path_disagreements)
            for i, path in enumerate(paths):
                clipped_len = min(
                    len(path['observations']) - (self.empowerment_horizon - 1),
                    self.num_model_samples - num_samples)
                bi, ei = num_samples, num_samples + clipped_len

                if self.empowerment_horizon > 1:
                    path['observations'] = path['observations'][:-(
                        self.empowerment_horizon - 1)]
                    path['next_observations'] = path['next_observations'][(
                        self.empowerment_horizon -
                        1):(self.empowerment_horizon - 1) + clipped_len]
                    path['actions'] = path['actions'][:-(
                        self.empowerment_horizon - 1)]

                observations[bi:ei] = path['observations'][:clipped_len]
                next_observations[bi:ei] = path[
                    'next_observations'][:clipped_len]
                actions[bi:ei] = path['actions'][:clipped_len]
                unfolded_latents[bi:ei] = latents[num_traj:num_traj + 1]
                disagreements[bi:ei] = path_disagreements[i, :clipped_len]

                num_samples += clipped_len
                num_traj += 1

                if num_samples >= self.num_model_samples:
                    break

        gt.stamp('generating rollouts', unique=False)

        if not self.relabel_rewards:
            rewards, (
                logp, logp_altz,
                denom), reward_diagnostics = self.calculate_intrinsic_rewards(
                    observations, next_observations, unfolded_latents)
            orig_rewards = rewards.copy()
            rewards, postproc_dict = self.reward_postprocessing(
                rewards, reward_kwargs=dict(disagreements=disagreements))
            reward_diagnostics.update(postproc_dict)

            if self._need_to_update_eval_statistics:
                self.eval_statistics.update(reward_diagnostics)

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Discriminator Log Pis',
                        logp,
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Discriminator Alt Log Pis',
                        logp_altz,
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Reward Denominator',
                        denom,
                    ))

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Rewards (Original)',
                        orig_rewards,
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Rewards (Processed)',
                        rewards,
                    ))

            gt.stamp('intrinsic reward calculation', unique=False)

        if self._need_to_update_eval_statistics:
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Latents',
                    latents,
                ))

        for t in range(self.num_model_samples):
            self.add_sample(
                observations[t],
                next_observations[t],
                next_observations[t],  # fix this
                actions[t],
                unfolded_latents[t],
                disagreement=disagreements[t],
            )

        gt.stamp('policy training', unique=False)

        self.train_discriminator(observations, next_observations,
                                 unfolded_latents)

        reward_kwargs = dict(
            disagreements=self._model_disagreements[:self._cur_replay_size])
        self.train_from_buffer(reward_kwargs=reward_kwargs)