Example #1
    def cost_function(self, x, it=0):
        with torch.no_grad():
            plans, obs = self.create_particles(x, self._observation)
            returns = self.get_plan_values(obs,
                                           plans).view(self.num_rollouts, -1)
            weighted_returns = self.get_weighted_returns(returns)
            costs = -ptu.get_numpy(weighted_returns)

        if self._need_to_update_diagnostics:
            self.diagnostics.update(
                create_stats_ordered_dict(
                    'Iteration %d Returns' % it,
                    ptu.get_numpy(weighted_returns),
                    always_show_all_stats=True,
                ))

            self.diagnostics.update(
                create_stats_ordered_dict(
                    'Iteration %d Particle Stds' % it,
                    np.std(ptu.get_numpy(returns), axis=-1),
                    always_show_all_stats=True,
                ))

            variance = weighted_returns.var()
            particle_variance = returns.var(dim=-1)
            self.diagnostics['Return Leftover Variance'] = \
                ptu.get_numpy(variance - particle_variance.mean()).mean()

        return costs
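cost_function negates the weighted returns so that a minimizing planner can consume them. A minimal sketch of how a CEM-style optimizer might call such a cost function (the optimizer loop and its hyperparameters are illustrative assumptions, not the repository's planner):

# A minimal sketch (assumption: cost_function(x, it) accepts a batch of flattened plans
# and returns one cost per plan, as above; pop_size/num_elites are illustrative).
import numpy as np

def cem_optimize(cost_function, plan_dim, num_iters=5, pop_size=64, num_elites=8):
    mean, std = np.zeros(plan_dim), np.ones(plan_dim)
    for it in range(num_iters):
        samples = mean + std * np.random.randn(pop_size, plan_dim)
        costs = cost_function(samples, it=it)            # lower cost = higher return
        elites = samples[np.argsort(costs)[:num_elites]]
        mean, std = elites.mean(axis=0), elites.std(axis=0) + 1e-6
    return mean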
Example #2
    def get_plan_values_batch_gt(self, obs, plans):
        returns = ptu.zeros(plans.shape[1])
        obs, plans = ptu.get_numpy(obs), ptu.get_numpy(plans)
        final_obs = np.copy(obs)
        for i in range(plans.shape[1]):
            returns[i], final_obs[i] = self._get_true_env_value(
                obs[i], plans[:, i])
        if self.value_func is not None:
            terminal_discount = self.discount**(self.horizon * self.repeat_length)
            returns += terminal_discount * self.value_func(
                ptu.from_numpy(final_obs), **self.value_func_kwargs)
        return returns
Example #3
    def get_diagnostics(self):
        stats = OrderedDict()
        stats.update(
            create_stats_ordered_dict(
                'mean',
                ptu.get_numpy(self.mean),
                # exclude_max_min=True,
            ))
        stats.update(
            create_stats_ordered_dict(
                'std',
                ptu.get_numpy(self.distribution.stddev),
            ))
        return stats
Example #4
    def predict_transition(self, obs, actions, infos):
        if self.sampling_mode == 'ts':
            preds = self._predict_transition_ts(obs, actions, infos)
        elif self.sampling_mode == 'uniform':
            preds = self._predict_transition_uniform(obs, actions, infos)
        else:
            raise NotImplementedError('MPC sampling_mode not recognized')

        next_obs = obs + preds[:, 2:]
        rewards = preds[:, 0]
        dones = preds[:, 1] > 0.5
        if self.reward_func is not None:
            given_rewards = self.reward_func(obs,
                                             actions,
                                             next_obs,
                                             num_timesteps=self.num_timesteps)
            self.diagnostics.update(
                create_stats_ordered_dict(
                    'Reward Squared Error',
                    ptu.get_numpy((given_rewards - rewards)**2),
                    always_show_all_stats=True,
                ))
            rewards = given_rewards

        return next_obs, rewards, dones
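predict_transition assumes the model output packs [reward, done, delta_obs] along the last axis, and that 'uniform' mode samples one ensemble member per particle. A hedged sketch of what _predict_transition_uniform could look like under that assumption (the ensemble call signature and output shape are hypothetical, not the repository's model API):

# A minimal sketch of the 'uniform' sampling mode (assumption: ensemble(inputs) returns
# predictions of shape (ensemble_size, batch, 2 + obs_dim)).
import torch

def predict_transition_uniform(ensemble, obs, actions):
    all_preds = ensemble(torch.cat([obs, actions], dim=-1))   # (E, B, 2 + obs_dim)
    ensemble_size, batch_size, _ = all_preds.shape
    # each particle independently uses a uniformly sampled ensemble member
    model_idx = torch.randint(ensemble_size, (batch_size,))
    preds = all_preds[model_idx, torch.arange(batch_size)]
    return preds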
Example #5
    def _get_model_plan_value(self, obs, plan):
        obs, plan = ptu.from_numpy(obs), ptu.from_numpy(plan)
        plans = plan.view(-1, self.horizon, self.plan_dim)
        plans = plans.permute(1, 0, 2)
        obs = obs.view(1, -1)
        returns = self.get_plan_values(obs, plans)
        return ptu.get_numpy(returns).mean()
Example #6
    def get_diagnostics(self):
        stats = OrderedDict()
        stats.update(
            create_stats_ordered_dict(
                'mean',
                ptu.get_numpy(self.mean),
            ))
        stats.update(
            create_stats_ordered_dict(
                'normal/std',
                ptu.get_numpy(self.normal_std),
            ))
        stats.update(
            create_stats_ordered_dict(
                'normal/log_std',
                ptu.get_numpy(torch.log(self.normal_std)),
            ))
        return stats
Example #7
    def get_plan_values_batch(self, obs, plans):
        """
        Get corresponding values of the plans (higher corresponds to better plans). Classes
        that don't want to plan over actions or use trajectory sampling can reimplement
        convert_plans_to_actions (& convert_plan_to_action) and/or predict_transition.
        plans is input as as torch (horizon_length, num_particles (total), plan_dim).
        We maintain trajectory infos as torch (n_part, info_dim (ex. obs_dim)).
        """

        if self.use_gt_model:
            return self.get_plan_values_batch_gt(obs, plans)

        n_part = plans.shape[1]  # *total* number of particles, NOT num_particles

        discount = 1
        returns, dones, infos = ptu.zeros(n_part), ptu.zeros(n_part), dict()

        # The effective planning horizon is self.horizon * self.repeat_length
        for t in range(self.horizon):
            for k in range(self.repeat_length):
                cur_actions = self.convert_plans_to_actions(obs, plans[t])
                obs, cur_rewards, cur_dones = self.predict_transition(
                    obs, cur_actions, infos)
                returns += discount * (1 - dones) * cur_rewards
                discount *= self.discount
                if self.predict_terminal:
                    dones = torch.max(dones, cur_dones.float())

        self.diagnostics.update(
            create_stats_ordered_dict(
                'MPC Termination',
                ptu.get_numpy(dones),
            ))

        if self.value_func is not None:
            terminal_values = self.value_func(
                obs, **self.value_func_kwargs).view(-1)
            returns += discount * (1 - dones) * terminal_values

            self.diagnostics.update(
                create_stats_ordered_dict(
                    'MPC Terminal Values',
                    ptu.get_numpy(terminal_values),
                ))

        return returns
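The loop above accumulates discounted rewards, masks particles that have already terminated, and finally bootstraps with a terminal value estimate. A minimal numpy sketch of the same recursion (the reward/done sequences and terminal values are made-up inputs for illustration):

# A minimal sketch of the return recursion used above, on plain numpy arrays.
import numpy as np

def accumulate_returns(rewards, dones, terminal_values, discount):
    # rewards, dones: (T, n_part); terminal_values: (n_part,)
    n_part = rewards.shape[1]
    returns, alive, disc = np.zeros(n_part), np.ones(n_part), 1.0
    for r, d in zip(rewards, dones):
        returns += disc * alive * r
        disc *= discount
        alive = np.minimum(alive, 1.0 - d)   # once done, stop accumulating
    return returns + disc * alive * terminal_values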
Example #8
    def train_from_paths(self, paths):

        """
        Path processing
        """

        paths = copy.deepcopy(paths)
        for path in paths:
            obs, next_obs = path['observations'], path['next_observations']
            states, next_states = obs[:,:self.state_dim], next_obs[:,:self.state_dim]
            goals = obs[:,self.state_dim:2*self.state_dim]
            actions = path['actions']
            terminals = path['terminals']  # likely always False, but kept in case termination handling is needed
            path_len = len(obs)

            # Relabel goals based on transitions taken
            relabeled_goals = []
            for t in range(len(obs)):
                relabeled_goals.append(self.relabel_goal_func(
                    states[t], actions[t], next_states[t], goals[t],
                ))
            relabeled_goals = np.array(relabeled_goals)

            # Add transitions & resampled goals to replay buffer
            for t in range(path_len):
                goals_t = goals[t:t+1]
                if self.relabel_method == 'future':
                    # sample num_sampled_goals relabeled goals from future timesteps
                    goal_inds = np.random.randint(t, path_len, self.num_sampled_goals)
                    goals_t = np.concatenate([goals_t, relabeled_goals[goal_inds]], axis=0)
                else:
                    raise NotImplementedError

                for k in range(len(goals_t)):
                    if not self.learn_reward_func:
                        r = self.reward_func(states[t], actions[t], next_states[t], goals_t[k])
                    else:
                        r = ptu.get_numpy(
                            self.learned_reward_func(
                                ptu.from_numpy(
                                    np.concatenate([next_states[t], goals_t[k]])))).mean()
                    self.replay_buffer.add_sample(
                        observation=np.concatenate([states[t], goals_t[k], obs[t,2*self.state_dim:]]),
                        action=actions[t],
                        reward=r,
                        terminal=terminals[t],  # not obvious what desired behavior is
                        next_observation=np.concatenate(
                            [next_states[t,:self.state_dim], goals_t[k], obs[t,2*self.state_dim:]]),
                        env_info=None,
                    )

        """
        Off-policy training
        """

        for _ in range(self.num_policy_steps):
            train_data = self.replay_buffer.random_batch(self.policy_batch_size)
            self.policy_trainer.train(train_data)
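The relabeling above only needs reward_func and relabel_goal_func to agree on the goal space. A minimal sketch of a compatible pair, assuming goals live in state space and a sparse threshold reward (the 0.05 threshold is illustrative, not from the source):

# Minimal goal-relabeling helpers matching the call signatures used above.
import numpy as np

def relabel_goal_func(state, action, next_state, goal):
    # relabel with the state actually reached
    return next_state.copy()

def reward_func(state, action, next_state, goal, threshold=0.05):
    # sparse {-1, 0} reward: 0 when within threshold of the goal
    return float(np.linalg.norm(next_state - goal) <= threshold) - 1.0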
Example #9
def calculate_baselines(paths, value_func):
    for path in paths:
        obs = ptu.from_numpy(
            np.concatenate(
                [path['observations'], path['next_observations'][-1:]],
                axis=0))
        values = torch.squeeze(value_func(obs), dim=-1)
        path['baselines'] = ptu.get_numpy(values)
        if path['terminals'][-1]:
            path['baselines'][-1] = 0
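calculate_baselines produces len(rewards) + 1 values per path (the extra bootstrap value is zeroed on terminal paths), which is the layout a GAE computation expects. A hedged sketch of a compatible calculate_advantages (GAE; Schulman et al. 2016), not necessarily the repository's path_functions implementation:

# A minimal GAE sketch compatible with the baselines layout above
# (len(path['baselines']) == len(path['rewards']) + 1).
import numpy as np

def calculate_advantages(paths, discount, gae_lambda, normalize=False):
    for path in paths:
        rewards, baselines = path['rewards'], path['baselines']
        deltas = rewards + discount * baselines[1:] - baselines[:-1]
        advantages = np.zeros_like(rewards)
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = deltas[t] + discount * gae_lambda * running
            advantages[t] = running
        path['advantages'] = advantages
    if normalize:
        flat = np.concatenate([p['advantages'] for p in paths])
        mean, std = flat.mean(), flat.std() + 1e-8
        for path in paths:
            path['advantages'] = (path['advantages'] - mean) / std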
Example #10
    def generate_latents(self, obs):
        if self._train_calls < self.num_unif_train_calls:
            return super().generate_latents(obs)
        latents, *_ = self.skill_practice_dist(ptu.from_numpy(obs))
        latents = ptu.get_numpy(latents)
        if self.epsilon_greedy > 0:
            unif_r = np.random.uniform(0, 1, size=latents.shape[0])
            eps_replace = unif_r < self.epsilon_greedy
            unif_latents = super().generate_latents(obs[eps_replace])
            latents[eps_replace] = unif_latents
        return latents
Example #11
    def train_from_paths(self, paths, train_discrim=True, train_policy=True):

        """
        Reading new paths: append latent to state
        Note that is equivalent to on-policy when latent buffer size = sum of paths length
        """

        epoch_obs, epoch_next_obs, epoch_latents = [], [], []

        for path in paths:
            obs = path['observations']
            next_obs = path['next_observations']
            actions = path['actions']
            latents = path.get('latents', None)
            path_len = len(obs) - self.empowerment_horizon + 1

            obs_latents = np.concatenate([obs, latents], axis=-1)
            log_probs = self.control_policy.get_log_probs(
                ptu.from_numpy(obs_latents),
                ptu.from_numpy(actions),
            )
            log_probs = ptu.get_numpy(log_probs)

            for t in range(path_len):
                self.add_sample(
                    obs[t],
                    next_obs[t+self.empowerment_horizon-1],
                    next_obs[t],
                    actions[t],
                    latents[t],
                    logprob=log_probs[t],
                )

                epoch_obs.append(obs[t:t+1])
                epoch_next_obs.append(next_obs[t+self.empowerment_horizon-1:t+self.empowerment_horizon])
                epoch_latents.append(np.expand_dims(latents[t], axis=0))

        epoch_obs = np.concatenate(epoch_obs, axis=0)
        epoch_next_obs = np.concatenate(epoch_next_obs, axis=0)
        epoch_latents = np.concatenate(epoch_latents, axis=0)

        self._epoch_size = len(epoch_obs)

        gt.stamp('policy training', unique=False)

        """
        The rest is shared, train from buffer
        """

        if train_discrim:
            self.train_discriminator(epoch_obs, epoch_next_obs, epoch_latents)
        if train_policy:
            self.train_from_buffer()
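The empowerment_horizon bookkeeping above pairs each observation with the observation reached H steps later. A small toy illustration of that indexing (the numbers are made up):

# Toy check of the empowerment_horizon indexing above.
import numpy as np

obs = np.arange(6).reshape(-1, 1)          # o_0 .. o_5
next_obs = np.arange(1, 7).reshape(-1, 1)  # o_1 .. o_6
H = 3
path_len = len(obs) - H + 1
pairs = [(obs[t, 0], next_obs[t + H - 1, 0]) for t in range(path_len)]
print(pairs)  # [(0, 3), (1, 4), (2, 5), (3, 6)]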
Example #12
    def get_action(self, state):
        if (self._steps_since_last_sample >= self.steps_between_sampling
                or self._last_latent is None) and not self.fixed_latent:
            latent = self.sample_latent(state)
            self._steps_since_last_sample = 0
        else:
            latent = self._last_latent
        self._steps_since_last_sample += 1

        state = ptu.from_numpy(state)
        sz = torch.cat((state, latent))
        action, *_ = self.policy.forward(sz)
        return ptu.get_numpy(action), dict()
Example #13
    def train_discriminator(self, obs, next_obs, latents):

        obs_deltas = next_obs - obs

        self.discriminator.train()
        start_discrim_loss = None

        for i in range(self.num_discrim_updates):
            batch = ppp.sample_batch(
                self.policy_batch_size,
                obs=obs,
                latents=latents,
                obs_deltas=obs_deltas,
            )
            batch = ptu.np_to_pytorch_batch(batch)

            if self.restrict_input_size > 0:
                batch['obs'] = batch['obs'][:, :self.restrict_input_size]
                batch['obs_deltas'] = batch['obs_deltas'][:, :self.restrict_input_size]

            # we embedded the latent in the observation, so (s, z) -> (delta s')
            discrim_loss = self.discriminator.get_loss(
                batch['obs'],
                batch['latents'],
                batch['obs_deltas'],
            )

            if i == 0:
                start_discrim_loss = discrim_loss

            self.discrim_optim.zero_grad()
            discrim_loss.backward()
            self.discrim_optim.step()

        if self._need_to_update_eval_statistics:
            self.eval_statistics['Discriminator Loss'] = ptu.get_numpy(discrim_loss).mean()
            self.eval_statistics['Discriminator Start Loss'] = ptu.get_numpy(start_discrim_loss).mean()

        gt.stamp('discriminator training', unique=False)
Example #14
def _create_paths(observations, actions, rewards, terminals, max_path_length):
    observations_np = ptu.get_numpy(observations)
    actions_np = ptu.get_numpy(actions)
    rewards_np = ptu.get_numpy(rewards)
    terminals_np = ptu.get_numpy(terminals)

    paths = []
    for i in range(len(observations)):
        rollout_len = 1
        while (rollout_len < max_path_length
               and terminals_np[i, rollout_len - 1, 0] < 0.5):  # just check 0 or 1
            rollout_len += 1
        paths.append(
            dict(
                observations=observations_np[i, :rollout_len],
                actions=actions_np[i, :rollout_len],
                rewards=rewards_np[i, :rollout_len],
                next_observations=observations_np[i, 1:rollout_len + 1],
                terminals=terminals_np[i, :rollout_len],
                agent_infos=[[] for _ in range(rollout_len)],
                env_infos=[[] for _ in range(rollout_len)],
            ))
    return paths
Example #15
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        noise = ptu.randn(next_actions.shape) * self.target_policy_noise
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target)**2
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = -q_output.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = -q_output.mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 1',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 2',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
        self._n_train_steps_total += 1
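ptu.soft_update_from_to presumably performs the usual Polyak averaging of target-network parameters. A minimal sketch of that helper under this assumption, with tau as the interpolation weight:

# A minimal sketch of Polyak averaging (assumption about what ptu.soft_update_from_to does).
import torch

def soft_update_from_to(source, target, tau):
    with torch.no_grad():
        for src_param, tgt_param in zip(source.parameters(), target.parameters()):
            # target <- (1 - tau) * target + tau * source
            tgt_param.data.mul_(1.0 - tau).add_(tau * src_param.data)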
Example #16
    def sample_paths(self, start_states, rollout_len):
        if self.sampling_mode == 'uniform':
            # Sample uniformly from a model of the ensemble (original MBPO; Janner et al. 2019)
            paths = mrf.policy(
                self.dynamics_model,
                self.policy_trainer.policy,
                start_states,
                max_path_length=rollout_len,
            )

        elif self.sampling_mode == 'mean_disagreement':
            # Sample with penalty for disagreement of the mean (MOReL; Kidambi et al. 2020)
            paths, disagreements = mrf.policy_with_disagreement(
                self.dynamics_model,
                self.policy_trainer.policy,
                start_states,
                max_path_length=rollout_len,
                disagreement_type='mean',
            )
            disagreements = ptu.get_numpy(disagreements)

            threshold = self.sampling_kwargs['threshold']
            penalty = self.sampling_kwargs['penalty']
            total_penalized, total_transitions = 0, 0
            for i, path in enumerate(paths):
                mask = np.zeros(len(path['rewards']))
                disagreement_values = disagreements[i]
                for t in range(len(path['rewards'])):
                    cur_mask = disagreement_values[t] > threshold
                    if t == 0:
                        mask[t] = cur_mask
                    elif cur_mask or mask[t - 1] > 0.5:
                        mask[t] = 1.
                    else:
                        mask[t] = 0.
                mask = mask.reshape(len(mask), 1)
                path['rewards'] = (1 - mask) * path['rewards'] - mask * penalty
                total_penalized += mask.sum()
                total_transitions += len(path['rewards'])

            self.eval_statistics['Percent of Transitions Penalized'] = \
                total_penalized / total_transitions
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Disagreement Values',
                    disagreements,
                ))

        elif self.sampling_mode == 'var_disagreement':
            # Sample with penalty for disagreement of the variance (MOPO; Yu et al. 2020)
            paths, disagreements = mrf.policy_with_disagreement(
                self.dynamics_model,
                self.policy_trainer.policy,
                start_states,
                max_path_length=rollout_len,
                disagreement_type='var',
            )
            disagreements = ptu.get_numpy(disagreements)

            reward_penalty = self.sampling_kwargs['reward_penalty']
            for i, path in enumerate(paths):
                path_disagreements = disagreements[
                    i, :len(path['rewards'])].reshape(*path['rewards'].shape)
                path['rewards'] -= reward_penalty * path_disagreements

            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Disagreement Values',
                    disagreements,
                ))

        else:
            raise NotImplementedError

        return paths
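Both penalty modes rely on a per-step ensemble disagreement signal. A hedged sketch of how 'mean' and 'var' disagreement could be computed from ensemble predictions, mirroring MOReL's pairwise mean discrepancy and MOPO's predicted-std penalty (shapes and names here are assumptions, not the repository's code):

# A minimal sketch (assumption: per-member predictions of shape (ensemble_size, T, obs_dim)).
import torch

def ensemble_disagreement(pred_means, pred_stds, disagreement_type='mean'):
    if disagreement_type == 'mean':
        # largest pairwise L2 distance between member mean predictions, per step
        diffs = pred_means.unsqueeze(0) - pred_means.unsqueeze(1)   # (E, E, T, obs_dim)
        return diffs.norm(dim=-1).amax(dim=(0, 1))                  # (T,)
    elif disagreement_type == 'var':
        # largest predicted standard-deviation norm across members, per step
        return pred_stds.norm(dim=-1).amax(dim=0)                   # (T,)
    raise NotImplementedError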
Example #17
    def get_current_latent(self):
        return ptu.get_numpy(self._last_latent)
Example #18
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy operations.
        """
        if self.policy_pre_activation_weight > 0:
            policy_actions, pre_tanh_value = self.policy(
                obs,
                return_preactivations=True,
            )
            pre_activation_policy_loss = ((pre_tanh_value**2).sum(
                dim=1).mean())
            q_output = self.qf(obs, policy_actions)
            raw_policy_loss = -q_output.mean()
            policy_loss = (
                raw_policy_loss +
                pre_activation_policy_loss * self.policy_pre_activation_weight)
        else:
            policy_actions = self.policy(obs)
            q_output = self.qf(obs, policy_actions)
            raw_policy_loss = policy_loss = -q_output.mean()
        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        # speed up computation by not backpropping these gradients
        next_actions = next_actions.detach()
        target_q_values = self.target_qf(
            next_obs,
            next_actions,
        )
        q_target = rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()
        q_target = torch.clamp(q_target, self.min_q_value, self.max_q_value)
        q_pred = self.qf(obs, actions)
        bellman_errors = (q_pred - q_target)**2
        raw_qf_loss = self.qf_criterion(q_pred, q_target)

        if self.qf_weight_decay > 0:
            reg_loss = self.qf_weight_decay * sum(
                torch.sum(param**2)
                for param in self.qf.regularizable_parameters())
            qf_loss = raw_qf_loss + reg_loss
        else:
            qf_loss = raw_qf_loss
        """
        Update Networks
        """

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        self._update_target_networks()
        """
        Save some statistics for eval using just one batch.
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics['Raw Policy Loss'] = np.mean(
                ptu.get_numpy(raw_policy_loss))
            self.eval_statistics['Preactivation Policy Loss'] = (
                self.eval_statistics['Policy Loss'] -
                self.eval_statistics['Raw Policy Loss'])
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Predictions',
                    ptu.get_numpy(q_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors',
                    ptu.get_numpy(bellman_errors),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
        self._n_train_steps_total += 1
Example #19
def calculate_contrastive_empowerment(
    discriminator,
    obs,
    next_obs,
    latents,
    num_prior_samples=512,
    distribution_type='uniform',
    split_group=4096 * 32,
    obs_mean=None,
    obs_std=None,
    return_diagnostics=False,
    prior=None,
):
    """
    Described in Sharma et al 2019.
    Approximate variational lower bound using estimate of s' from s, z.
    Uses contrastive negatives to approximate denominator.
    """

    discriminator.eval()

    if obs_mean is not None:
        obs = (obs - obs_mean) / (obs_std + 1e-6)
        next_obs = (next_obs - obs_mean) / (obs_std + 1e-6)

    obs_deltas = ptu.from_numpy(next_obs - obs)
    obs_altz = np.concatenate([obs] * num_prior_samples, axis=0)

    with torch.no_grad():
        logp = discriminator.get_log_prob(
            ptu.from_numpy(obs),
            ptu.from_numpy(latents),
            obs_deltas,
        )
        logp = ptu.get_numpy(logp)

    if distribution_type == 'uniform':
        latent_altz = np.random.uniform(low=-1,
                                        high=1,
                                        size=(obs_altz.shape[0],
                                              latents.shape[1]))
    elif distribution_type == 'prior':
        if prior is None:
            raise AssertionError('prior specified but not passed in')
        obs_t = ptu.from_numpy(obs_altz)
        latent_altz, *_ = prior.get_action(obs_t, deterministic=False)
    else:
        raise NotImplementedError('distribution_type not found')

    # keep track of next obs/delta
    next_obs_altz = np.concatenate([next_obs - obs] * num_prior_samples,
                                   axis=0)

    with torch.no_grad():
        if obs_altz.shape[0] <= split_group:
            logp_altz = ptu.get_numpy(
                discriminator.get_log_prob(
                    ptu.from_numpy(obs_altz),
                    ptu.from_numpy(latent_altz),
                    ptu.from_numpy(next_obs_altz),
                ))
        else:
            logp_altz = []
            for split_idx in range(obs_altz.shape[0] // split_group):
                start_split = split_idx * split_group
                end_split = (split_idx + 1) * split_group
                logp_altz.append(
                    ptu.get_numpy(
                        discriminator.get_log_prob(
                            ptu.from_numpy(obs_altz[start_split:end_split]),
                            ptu.from_numpy(latent_altz[start_split:end_split]),
                            ptu.from_numpy(
                                next_obs_altz[start_split:end_split]),
                        )))
            if obs_altz.shape[0] % split_group:
                start_split = obs_altz.shape[0] % split_group
                logp_altz.append(
                    ptu.get_numpy(
                        discriminator.get_log_prob(
                            ptu.from_numpy(obs_altz[-start_split:]),
                            ptu.from_numpy(latent_altz[-start_split:]),
                            ptu.from_numpy(next_obs_altz[-start_split:]),
                        )))
            logp_altz = np.concatenate(logp_altz)
    logp_altz = np.array(np.array_split(logp_altz, num_prior_samples))

    if return_diagnostics:
        diagnostics = dict()
        orig_rep = np.repeat(np.expand_dims(logp, axis=0),
                             axis=0,
                             repeats=num_prior_samples)
        diagnostics['Pct Random Skills > Original'] = (orig_rep < logp_altz).mean()

    # final DADS reward
    intrinsic_reward = np.log(num_prior_samples + 1) - np.log(1 + np.exp(
        np.clip(logp_altz - logp.reshape(1, -1), -50, 50)).sum(axis=0))

    if not return_diagnostics:
        return intrinsic_reward, (logp, logp_altz, logp - intrinsic_reward)
    else:
        return intrinsic_reward, (logp, logp_altz,
                                  logp - intrinsic_reward), diagnostics
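The intrinsic reward line implements the DADS estimator r = log(L + 1) - log(1 + sum_i exp(log q(s'|s, z_i) - log q(s'|s, z))), where L is num_prior_samples. A small numeric check of that computation with made-up log-probabilities:

# Numeric check of the intrinsic reward line above (toy values).
import numpy as np

logp = np.array([0.0, -1.0])                         # log q(s'|s, z) for 2 transitions
logp_altz = np.array([[-2.0, -0.5], [-3.0, -1.5]])   # (L=2, 2) alternative-skill log-probs
L = logp_altz.shape[0]
reward = np.log(L + 1) - np.log(
    1 + np.exp(np.clip(logp_altz - logp.reshape(1, -1), -50, 50)).sum(axis=0))
print(reward)  # higher when the true skill explains s' better than random skills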
Example #20
    def train_from_buffer(self,
                          replay_buffer,
                          holdout_pct=0.2,
                          max_grad_steps=1000,
                          epochs_since_last_update=5):
        self._n_train_steps_total += 1
        if self._n_train_steps_total % self.train_call_freq > 0 and self._n_train_steps_total > 1:
            return

        data = replay_buffer.get_transitions()
        x = data[:, :self.obs_dim + self.action_dim]  # inputs  s, a
        y = data[:, self.obs_dim + self.action_dim:]  # predict r, d, ns
        y[:, -self.obs_dim:] -= x[:, :self.obs_dim]  # predict delta in the state

        # normalize network inputs
        self.ensemble.fit_input_stats(x)

        # generate holdout set
        inds = np.random.permutation(data.shape[0])
        x, y = x[inds], y[inds]

        n_train = max(int((1 - holdout_pct) * data.shape[0]),
                      data.shape[0] - 8092)
        n_test = data.shape[0] - n_train

        x_train, y_train = x[:n_train], y[:n_train]
        x_test, y_test = x[n_train:], y[n_train:]
        x_test, y_test = ptu.from_numpy(x_test), ptu.from_numpy(y_test)

        # train until holdout set convergence
        num_epochs, num_steps = 0, 0
        num_epochs_since_last_update = 0
        best_holdout_loss = float('inf')
        num_batches = int(np.ceil(n_train / self.batch_size))

        while num_epochs_since_last_update < epochs_since_last_update and num_steps < max_grad_steps:
            # generate idx for each model to bootstrap
            self.ensemble.train()
            for b in range(num_batches):
                b_idxs = np.random.randint(n_train,
                                           size=(self.ensemble_size *
                                                 self.batch_size))
                x_batch, y_batch = x_train[b_idxs], y_train[b_idxs]
                x_batch, y_batch = ptu.from_numpy(x_batch), ptu.from_numpy(
                    y_batch)
                x_batch = x_batch.view(self.ensemble_size, self.batch_size, -1)
                y_batch = y_batch.view(self.ensemble_size, self.batch_size, -1)
                loss = self.ensemble.get_loss(x_batch, y_batch)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            num_steps += num_batches

            # stop training based on holdout loss improvement
            self.ensemble.eval()
            with torch.no_grad():
                holdout_losses, holdout_errors = self.ensemble.get_loss(
                    x_test, y_test, split_by_model=True, return_l2_error=True)
            holdout_loss = sum(
                sorted(holdout_losses)[:self.num_elites]) / self.num_elites

            if num_epochs == 0 or \
               (best_holdout_loss - holdout_loss) / abs(best_holdout_loss) > 0.01:
                best_holdout_loss = holdout_loss
                num_epochs_since_last_update = 0
            else:
                num_epochs_since_last_update += 1

            num_epochs += 1

        self.ensemble.elites = np.argsort(holdout_losses)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False

            self.eval_statistics['Model Elites Holdout Loss'] = \
                np.mean(ptu.get_numpy(holdout_loss))
            self.eval_statistics['Model Holdout Loss'] = \
                np.mean(ptu.get_numpy(sum(holdout_losses))) / self.ensemble_size
            self.eval_statistics['Model Training Epochs'] = num_epochs
            self.eval_statistics['Model Training Steps'] = num_steps

            for i in range(self.ensemble_size):
                name = 'M%d' % (i + 1)
                self.eval_statistics[name + ' Loss'] = \
                    np.mean(ptu.get_numpy(holdout_losses[i]))
                self.eval_statistics[name + ' L2 Error'] = \
                    np.mean(ptu.get_numpy(holdout_errors[i]))
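Each ensemble member is trained on its own bootstrap-style batch by drawing ensemble_size * batch_size indices with replacement and reshaping, as in the training loop above. A small toy illustration of that batching (sizes are made up):

# Toy illustration of the bootstrap batching used in train_from_buffer above.
import numpy as np

ensemble_size, batch_size, n_train, x_dim = 3, 4, 100, 5
x_train = np.random.randn(n_train, x_dim)
b_idxs = np.random.randint(n_train, size=ensemble_size * batch_size)
x_batch = x_train[b_idxs].reshape(ensemble_size, batch_size, x_dim)
print(x_batch.shape)  # (3, 4, 5); member i trains on x_batch[i]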
Example #21
    def train_from_torch(self, batch):

        self._train_calls += 1
        if self._train_calls % self.train_every > 0:
            return

        rollout_len = self.rollout_len_func(self._n_train_steps_total)
        num_model_rollouts = max(self.num_model_samples // rollout_len, 1)
        self.eval_statistics['Rollout Length'] = rollout_len

        real_batch = self.replay_buffer.random_batch(num_model_rollouts)
        start_states = real_batch['observations']
        latents = self.generate_latents(start_states)

        observations = np.zeros((self.num_model_samples, self.obs_dim))
        next_observations = np.zeros((self.num_model_samples, self.obs_dim))
        actions = np.zeros((self.num_model_samples, self.action_dim))
        unfolded_latents = np.zeros((self.num_model_samples, self.latent_dim))
        disagreements = np.zeros(self.num_model_samples)

        num_samples, b_ind, num_traj = 0, 0, 0
        while num_samples < self.num_model_samples:
            e_ind = b_ind + 4192 // rollout_len
            with torch.no_grad():
                paths, path_disagreements = self.generate_paths(
                    dynamics_model=self.dynamics_model,
                    control_policy=self.control_policy,
                    start_states=start_states[b_ind:e_ind],
                    latents=ptu.from_numpy(latents[b_ind:e_ind]),
                    rollout_len=rollout_len,
                )

            b_ind = e_ind

            path_disagreements = ptu.get_numpy(path_disagreements)
            for i, path in enumerate(paths):
                clipped_len = min(
                    len(path['observations']) - (self.empowerment_horizon - 1),
                    self.num_model_samples - num_samples)
                bi, ei = num_samples, num_samples + clipped_len

                if self.empowerment_horizon > 1:
                    path['observations'] = path['observations'][:-(
                        self.empowerment_horizon - 1)]
                    path['next_observations'] = path['next_observations'][(
                        self.empowerment_horizon -
                        1):(self.empowerment_horizon - 1) + clipped_len]
                    path['actions'] = path['actions'][:-(
                        self.empowerment_horizon - 1)]

                observations[bi:ei] = path['observations'][:clipped_len]
                next_observations[bi:ei] = path[
                    'next_observations'][:clipped_len]
                actions[bi:ei] = path['actions'][:clipped_len]
                unfolded_latents[bi:ei] = latents[num_traj:num_traj + 1]
                disagreements[bi:ei] = path_disagreements[i, :clipped_len]

                num_samples += clipped_len
                num_traj += 1

                if num_samples >= self.num_model_samples:
                    break

        gt.stamp('generating rollouts', unique=False)

        if not self.relabel_rewards:
            rewards, (
                logp, logp_altz,
                denom), reward_diagnostics = self.calculate_intrinsic_rewards(
                    observations, next_observations, unfolded_latents)
            orig_rewards = rewards.copy()
            rewards, postproc_dict = self.reward_postprocessing(
                rewards, reward_kwargs=dict(disagreements=disagreements))
            reward_diagnostics.update(postproc_dict)

            if self._need_to_update_eval_statistics:
                self.eval_statistics.update(reward_diagnostics)

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Discriminator Log Pis',
                        logp,
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Discriminator Alt Log Pis',
                        logp_altz,
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Reward Denominator',
                        denom,
                    ))

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Rewards (Original)',
                        orig_rewards,
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Rewards (Processed)',
                        rewards,
                    ))

            gt.stamp('intrinsic reward calculation', unique=False)

        if self._need_to_update_eval_statistics:
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Latents',
                    latents,
                ))

        for t in range(self.num_model_samples):
            self.add_sample(
                observations[t],
                next_observations[t],
                next_observations[t],  # fix this
                actions[t],
                unfolded_latents[t],
                disagreement=disagreements[t],
            )

        gt.stamp('policy training', unique=False)

        self.train_discriminator(observations, next_observations,
                                 unfolded_latents)

        reward_kwargs = dict(
            disagreements=self._model_disagreements[:self._cur_replay_size])
        self.train_from_buffer(reward_kwargs=reward_kwargs)
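rollout_len_func maps the training step count to a model-rollout length. MBPO-style implementations typically interpolate the length linearly over training; a hedged sketch of such a schedule (the function name, defaults, and epoch bounds are illustrative assumptions, not from the source):

# A minimal rollout-length schedule sketch (linear interpolation between two epochs).
def linear_rollout_len(n_train_steps, steps_per_epoch=1000,
                       start_epoch=20, end_epoch=100, min_len=1, max_len=15):
    epoch = n_train_steps // steps_per_epoch
    frac = (epoch - start_epoch) / max(end_epoch - start_epoch, 1)
    frac = min(max(frac, 0.0), 1.0)
    return int(min_len + frac * (max_len - min_len))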
Example #22
    def train_from_paths(self, paths):
        """
        Path preprocessing; have to copy so we don't modify when paths are used elsewhere
        """

        paths = copy.deepcopy(paths)
        for path in paths:
            # Other places like to have an extra dimension so that all arrays are 2D
            path['terminals'] = np.squeeze(path['terminals'], axis=-1)
            path['rewards'] = np.squeeze(path['rewards'], axis=-1)

            # Reward normalization; divide by std of reward in replay buffer
            path['rewards'] = np.clip(
                path['rewards'] / (self._reward_std + 1e-3), -10, 10)

        obs, actions = [], []
        for path in paths:
            obs.append(path['observations'])
            actions.append(path['actions'])
        obs = np.concatenate(obs, axis=0)
        actions = np.concatenate(actions, axis=0)

        obs_tensor, act_tensor = ptu.from_numpy(obs), ptu.from_numpy(actions)
        """
        Policy training loop
        """

        old_policy = copy.deepcopy(self.policy)
        with torch.no_grad():
            log_probs_old = old_policy.get_log_probs(
                obs_tensor, act_tensor).squeeze(dim=-1)

        rem_value_epochs = self.num_epochs
        for epoch in range(self.num_policy_epochs):
            """
            Recompute advantages at the beginning of each epoch. This allows for advantages
                to utilize the latest value function.
            Note: while this is not present in most implementations, it is recommended
                  by Andrychowicz et al. 2020.
            """

            path_functions.calculate_baselines(paths, self.value_func)
            path_functions.calculate_returns(paths, self.discount)
            path_functions.calculate_advantages(
                paths,
                self.discount,
                self.gae_lambda,
                self.normalize_advantages,
            )

            advantages, returns, baselines = [], [], []
            for path in paths:
                advantages = np.append(advantages, path['advantages'])
                returns = np.append(returns, path['returns'])

            if epoch == 0 and self._need_to_update_eval_statistics:
                with torch.no_grad():
                    values = torch.squeeze(self.value_func(obs_tensor), dim=-1)
                    values_np = ptu.get_numpy(values)
                first_val_loss = ((returns - values_np)**2).mean()

            old_params = self.policy.get_param_values()

            num_policy_steps = len(advantages) // self.policy_batch_size
            for _ in range(num_policy_steps):
                if num_policy_steps == 1:
                    batch = dict(
                        observations=obs,
                        actions=actions,
                        advantages=advantages,
                    )
                else:
                    batch = ppp.sample_batch(
                        self.policy_batch_size,
                        observations=obs,
                        actions=actions,
                        advantages=advantages,
                    )
                policy_loss, kl = self.train_policy(batch, old_policy)

            with torch.no_grad():
                log_probs = self.policy.get_log_probs(
                    obs_tensor, act_tensor).squeeze(dim=-1)
            kl = (log_probs_old - log_probs).mean()

            if (self.target_kl is not None
                    and kl > 1.5 * self.target_kl) or (kl != kl):
                if epoch > 0 or kl != kl:  # nan check
                    self.policy.set_param_values(old_params)
                break

            num_value_steps = len(advantages) // self.value_batch_size
            for i in range(num_value_steps):
                batch = ppp.sample_batch(
                    self.value_batch_size,
                    observations=obs,
                    targets=returns,
                )
                value_loss = self.train_value(batch)
            rem_value_epochs -= 1

        # Ensure the value function is always updated for the maximum number
        # of epochs, regardless of if the policy wants to terminate early.
        for _ in range(rem_value_epochs):
            num_value_steps = len(advantages) // self.value_batch_size
            for i in range(num_value_steps):
                batch = ppp.sample_batch(
                    self.value_batch_size,
                    observations=obs,
                    targets=returns,
                )
                value_loss = self.train_value(batch)

        if self._need_to_update_eval_statistics:
            with torch.no_grad():
                _, _, _, log_pi, *_ = self.policy(obs_tensor,
                                                  return_log_prob=True)
                values = torch.squeeze(self.value_func(obs_tensor), dim=-1)
                values_np = ptu.get_numpy(values)

            errors = returns - values_np
            explained_variance = 1 - (np.var(errors) / np.var(returns))
            value_loss = errors**2

            self.eval_statistics['Num Epochs'] = epoch + 1

            self.eval_statistics['Policy Loss'] = ptu.get_numpy(
                policy_loss).mean()
            self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl).mean()
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantages',
                    advantages,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Returns',
                    returns,
                ))

            self.eval_statistics['Value Loss'] = value_loss.mean()
            self.eval_statistics['First Value Loss'] = first_val_loss
            self.eval_statistics[
                'Value Explained Variance'] = explained_variance
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Values',
                    ptu.get_numpy(values),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Value Squared Errors',
                    value_loss,
                ))
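train_policy is not shown here. A hedged sketch of what it might do, assuming a PPO-style clipped surrogate objective (clip_eps and the policy_optimizer attribute are illustrative; ptu is the same tensor helper used throughout these examples):

# A minimal sketch of a PPO-style train_policy (assumed, not the repository's code).
import torch

def train_policy(self, batch, old_policy, clip_eps=0.2):
    obs = ptu.from_numpy(batch['observations'])
    actions = ptu.from_numpy(batch['actions'])
    advantages = ptu.from_numpy(batch['advantages'])

    log_probs = self.policy.get_log_probs(obs, actions).squeeze(dim=-1)
    with torch.no_grad():
        log_probs_old = old_policy.get_log_probs(obs, actions).squeeze(dim=-1)

    # clipped surrogate objective on the probability ratio
    ratio = torch.exp(log_probs - log_probs_old)
    surrogate = torch.min(ratio * advantages,
                          torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages)
    policy_loss = -surrogate.mean()

    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    kl = (log_probs_old - log_probs).mean()
    return policy_loss, kl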
Example #23
    def train_from_torch(self, batch):
        obs = batch['observations']
        next_obs = batch['next_observations']
        actions = batch['actions']
        rewards = batch['rewards']
        terminals = batch.get('terminals', ptu.zeros(rewards.shape[0], 1))

        """
        Policy and Alpha Loss
        """
        _, policy_mean, policy_logstd, *_ = self.policy(obs)
        dist = TanhNormal(policy_mean, policy_logstd.exp())
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        log_pi = log_pi.sum(dim=-1, keepdims=True)
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
        policy_loss = (alpha * log_pi - q_new_actions).mean()

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        _, next_policy_mean, next_policy_logstd, *_ = self.policy(next_obs)
        next_dist = TanhNormal(next_policy_mean, next_policy_logstd.exp())
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        new_log_pi = new_log_pi.sum(dim=-1, keepdims=True)
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        future_values = (1. - terminals) * self.discount * target_q_values
        q_target = self.reward_scale * rewards + future_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        if self.use_automatic_entropy_tuning:
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.try_update_target_networks()

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False

            policy_loss = (log_pi - q_new_actions).mean()
            policy_avg_std = torch.exp(policy_logstd).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_logstd),
            ))
            self.eval_statistics['Policy std'] = np.mean(ptu.get_numpy(policy_avg_std))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

        self._n_train_steps_total += 1
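The automatic entropy tuning above follows the standard SAC convention in which target_entropy defaults to -action_dim and alpha is optimized through log_alpha. A minimal standalone sketch of that setup (the action dimension, learning rate, and placeholder log_pi values are made up):

# Standalone sketch of SAC-style automatic entropy tuning.
import torch

action_dim = 6
target_entropy = -float(action_dim)                     # standard SAC heuristic: -|A|
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

# per batch (log_pi would come from the policy sample, as in train_from_torch above)
log_pi = torch.randn(256, 1)                            # placeholder values
alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()
alpha = log_alpha.exp()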