Example 1
    def create_particles(self, plans, obs, n_part=None):
        n_opt = plans.shape[0]
        if n_part is None:
            n_part = self.num_particles * self.num_models

        # (N, H*m)
        plans = ptu.from_numpy(plans)
        # (N, H, m)
        plans = plans.view(-1, self.horizon, self.plan_dim)
        # (H, N, m)
        transposed = plans.transpose(0, 1)
        # (H, N, 1, m)
        expanded = transposed[:, :, None]
        # (H, N, P, m)
        tiled = expanded.expand(-1, -1, n_part, -1)
        # (H, N*P, m)
        plans = tiled.contiguous().view(self.horizon, -1, self.plan_dim)

        # (n,)
        obs = ptu.from_numpy(obs)
        # (1, n)
        obs = obs[None]
        # (N*P, n)
        obs = obs.expand(n_opt * n_part, -1)

        return plans, obs
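A minimal standalone sketch of the same reshape/expand chain, using plain torch and made-up dimensions (N plans, horizon H, plan dimension m, observation dimension n, P particles are all assumed values), to verify the shape bookkeeping in the comments above:

import torch

N, H, m, n, P = 5, 3, 2, 4, 6            # num plans, horizon, plan dim, obs dim, particles

plans = torch.randn(N, H * m)            # (N, H*m), flat plans from the optimizer
plans = plans.view(-1, H, m)             # (N, H, m)
plans = plans.transpose(0, 1)            # (H, N, m)
plans = plans[:, :, None]                # (H, N, 1, m)
plans = plans.expand(-1, -1, P, -1)      # (H, N, P, m), one copy per particle
plans = plans.contiguous().view(H, -1, m)
assert plans.shape == (H, N * P, m)

obs = torch.randn(n)                     # (n,)
obs = obs[None].expand(N * P, -1)        # (N*P, n), broadcast to every particle
assert obs.shape == (N * P, n)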
Example 2
 def _get_model_plan_value(self, obs, plan):
     obs, plan = ptu.from_numpy(obs), ptu.from_numpy(plan)
     plans = plan.view(-1, self.horizon, self.plan_dim)
     plans = plans.permute(1, 0, 2)
     obs = obs.view(1, -1)
     returns = self.get_plan_values(obs, plans)
     return ptu.get_numpy(returns).mean()
Example 3
 def fit_input_stats(self, data, mask=None):
     mean = np.mean(data, axis=0, keepdims=True)
     std = np.std(data, axis=0, keepdims=True)
     std[std != std] = 0
     std[std < 1e-12] = 1.0
     if mask is not None:
         mean *= mask
         std = mask * std + (1-mask) * np.ones(self.input_size)
     self.input_mu.data = ptu.from_numpy(mean)
     self.input_std.data = ptu.from_numpy(std)
Example 4
    def train_from_paths(self, paths, train_discrim=True, train_policy=True):

        """
        Reading new paths: append latent to state
        Note that is equivalent to on-policy when latent buffer size = sum of paths length
        """

        epoch_obs, epoch_next_obs, epoch_latents = [], [], []

        for path in paths:
            obs = path['observations']
            next_obs = path['next_observations']
            actions = path['actions']
            latents = path.get('latents', None)
            path_len = len(obs) - self.empowerment_horizon + 1

            obs_latents = np.concatenate([obs, latents], axis=-1)
            log_probs = self.control_policy.get_log_probs(
                ptu.from_numpy(obs_latents),
                ptu.from_numpy(actions),
            )
            log_probs = ptu.get_numpy(log_probs)

            for t in range(path_len):
                self.add_sample(
                    obs[t],
                    next_obs[t+self.empowerment_horizon-1],
                    next_obs[t],
                    actions[t],
                    latents[t],
                    logprob=log_probs[t],
                )

                epoch_obs.append(obs[t:t+1])
                epoch_next_obs.append(next_obs[t+self.empowerment_horizon-1:t+self.empowerment_horizon])
                epoch_latents.append(np.expand_dims(latents[t], axis=0))

        epoch_obs = np.concatenate(epoch_obs, axis=0)
        epoch_next_obs = np.concatenate(epoch_next_obs, axis=0)
        epoch_latents = np.concatenate(epoch_latents, axis=0)

        self._epoch_size = len(epoch_obs)

        gt.stamp('policy training', unique=False)

        """
        The rest is shared, train from buffer
        """

        if train_discrim:
            self.train_discriminator(epoch_obs, epoch_next_obs, epoch_latents)
        if train_policy:
            self.train_from_buffer()
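The loop above pairs obs[t] with next_obs[t + H - 1], i.e. the state reached H environment steps after s_t, where H is the empowerment horizon. A minimal standalone sketch of that index arithmetic with toy data (the 1-D "states" and H = 3 are assumptions for illustration):

import numpy as np

H = 3                                    # assumed empowerment horizon
obs = np.arange(8)[:, None]              # states s_0 .. s_7
next_obs = np.arange(1, 9)[:, None]      # states s_1 .. s_8

path_len = len(obs) - H + 1
pairs = [(obs[t, 0], next_obs[t + H - 1, 0]) for t in range(path_len)]
print(pairs)                             # [(0, 3), (1, 4), ..., (5, 8)], i.e. (s_t, s_{t+H})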
Example 5
    def train_value(self, batch):
        obs = ptu.from_numpy(batch['observations'])
        targets = ptu.from_numpy(batch['targets'])

        value_preds = torch.squeeze(self.value_func(obs), dim=-1)
        value_loss = 0.5 * ((value_preds - targets)**2).mean()

        self.value_optim.zero_grad()
        value_loss.backward()
        self.value_optim.step()

        return value_loss
Example 6
 def sample_latent(self, state=None):
     if self.unconditional or state is None:  # this will probably be changed
         latent = self.prior.sample()
     else:
         latent = self.prior.forward(ptu.from_numpy(state))
     self.set_latent(latent)
     return latent
Example 7
    def train_policy(self, batch, old_policy):
        obs = ptu.from_numpy(batch['observations'])
        actions = ptu.from_numpy(batch['actions'])
        advantages = ptu.from_numpy(batch['advantages'])

        objective, kl = self.policy_objective(obs, actions, advantages,
                                              old_policy)
        policy_loss = -objective

        self.policy_optim.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(),
                                       self.max_grad_norm)
        self.policy_optim.step()

        return policy_loss, kl
Example 8
def _create_full_tensors(start_states, max_path_length, obs_dim, action_dim):
    num_rollouts = start_states.shape[0]
    observations = ptu.zeros((num_rollouts, max_path_length + 1, obs_dim))
    observations[:, 0] = ptu.from_numpy(start_states)
    actions = ptu.zeros((num_rollouts, max_path_length, action_dim))
    rewards = ptu.zeros((num_rollouts, max_path_length, 1))
    terminals = ptu.zeros((num_rollouts, max_path_length, 1))
    return observations, actions, rewards, terminals
Example 9
 def set_param_values(self, new_params):
     current_idx = 0
     for idx, param in enumerate(self.trainable_params):
         vals = new_params[current_idx:current_idx + self.param_sizes[idx]]
         vals = vals.reshape(self.param_shapes[idx])
         param.data = ptu.from_numpy(vals).float()
         current_idx += self.param_sizes[idx]
     self.trainable_params[-1].data = torch.clamp(self.trainable_params[-1], LOG_SIG_MIN)
Example 10
    def train_from_paths(self, paths):

        """
        Path processing
        """

        paths = copy.deepcopy(paths)
        for path in paths:
            obs, next_obs = path['observations'], path['next_observations']
            states, next_states = obs[:,:self.state_dim], next_obs[:,:self.state_dim]
            goals = obs[:,self.state_dim:2*self.state_dim]
            actions = path['actions']
            terminals = path['terminals']  # this is probably always False, but might want it?
            path_len = len(obs)

            # Relabel goals based on transitions taken
            relabeled_goals = []
            for t in range(len(obs)):
                relabeled_goals.append(self.relabel_goal_func(
                    states[t], actions[t], next_states[t], goals[t],
                ))
            relabeled_goals = np.array(relabeled_goals)

            # Add transitions & resampled goals to replay buffer
            for t in range(path_len):
                goals_t = goals[t:t+1]
                for _ in range(self.num_sampled_goals):
                    if self.relabel_method == 'future':
                        goal_inds = np.random.randint(t, path_len, self.num_sampled_goals)
                        goals_t = np.concatenate([goals_t, relabeled_goals[goal_inds]], axis=0)
                    else:
                        raise NotImplementedError

                for k in range(len(goals_t)):
                    if not self.learn_reward_func:
                        r = self.reward_func(states[t], actions[t], next_states[t], goals_t[k])
                    else:
                        r = ptu.get_numpy(
                            self.learned_reward_func(
                                ptu.from_numpy(
                                    np.concatenate([next_states[t], goals[t]])))).mean()
                    self.replay_buffer.add_sample(
                        observation=np.concatenate([states[t], goals_t[k], obs[t,2*self.state_dim:]]),
                        action=actions[t],
                        reward=r,
                        terminal=terminals[t],  # not obvious what desired behavior is
                        next_observation=np.concatenate(
                            [next_states[t,:self.state_dim], goals_t[k], obs[t,2*self.state_dim:]]),
                        env_info=None,
                    )

        """
        Off-policy training
        """

        for _ in range(self.num_policy_steps):
            train_data = self.replay_buffer.random_batch(self.policy_batch_size)
            self.policy_trainer.train(train_data)
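A minimal standalone sketch of the 'future' relabeling branch above, with toy data; relabel_goal is a hypothetical stand-in for self.relabel_goal_func (assumed here to return the achieved next state), not the repository's function:

import numpy as np

def relabel_goal(state, action, next_state, goal):
    # hypothetical achieved-goal function: the achieved goal is the next state
    return next_state

path_len, num_sampled_goals = 5, 2
states = np.arange(path_len, dtype=float)[:, None]
next_states = states + 1.0
actions = np.zeros((path_len, 1))
goals = np.full((path_len, 1), 10.0)     # original, never-achieved goal

relabeled_goals = np.array([
    relabel_goal(states[t], actions[t], next_states[t], goals[t])
    for t in range(path_len)
])

t = 1
goal_inds = np.random.randint(t, path_len, num_sampled_goals)   # sample from the future only
goals_t = np.concatenate([goals[t:t + 1], relabeled_goals[goal_inds]], axis=0)
# goals_t: the original goal plus goals actually achieved after step t
print(goals_t)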
Example 11
def calculate_baselines(paths, value_func):
    for path in paths:
        obs = ptu.from_numpy(
            np.concatenate(
                [path['observations'], path['next_observations'][-1:]],
                axis=0))
        values = torch.squeeze(value_func(obs), dim=-1)
        path['baselines'] = ptu.get_numpy(values)
        if path['terminals'][-1]:
            path['baselines'][-1] = 0
Example 12
 def generate_latents(self, obs):
     if self._train_calls < self.num_unif_train_calls:
         return super().generate_latents(obs)
     latents, *_ = self.skill_practice_dist(ptu.from_numpy(obs))
     latents = ptu.get_numpy(latents)
     if self.epsilon_greedy > 0:
         unif_r = np.random.uniform(0, 1, size=latents.shape[0])
         eps_replace = unif_r < self.epsilon_greedy
         unif_latents = super().generate_latents(obs[eps_replace])
         latents[eps_replace] = unif_latents
     return latents
Example 13
    def train_policy(self, batch, old_policy):
        obs = ptu.from_numpy(batch['observations'])
        actions = ptu.from_numpy(batch['actions'])
        advantages = ptu.from_numpy(batch['advantages'])

        log_probs = torch.squeeze(self.policy.get_log_probs(obs, actions), dim=-1)
        log_probs_old = torch.squeeze(old_policy.get_log_probs(obs, actions), dim=-1)
        kl = (log_probs_old - log_probs).mean()

        vpg_grad, cpi_surr = self.flat_vpg(obs, actions, advantages, old_policy)
        hvp = self.build_Hvp_eval([obs, actions, old_policy], regu_coef=self.FIM_invert_args['damping'])
        npg_grad = cg_solve(hvp, vpg_grad, x_0=vpg_grad.copy(), cg_iters=self.FIM_invert_args['iters'])

        alpha = np.sqrt(np.abs(self.normalized_step_size / (np.dot(vpg_grad.T, npg_grad) + 1e-20)))

        cur_params = self.policy.get_param_values()
        new_params = cur_params + alpha * npg_grad
        self.policy.set_param_values(new_params)

        return -cpi_surr, kl
Example 14
 def get_plan_values_batch_gt(self, obs, plans):
     returns = ptu.zeros(plans.shape[1])
     obs, plans = ptu.get_numpy(obs), ptu.get_numpy(plans)
     final_obs = np.copy(obs)
     for i in range(plans.shape[1]):
         returns[i], final_obs[i] = self._get_true_env_value(
             obs[i], plans[:, i])
     if self.value_func is not None:
         returns += (self.discount**(
             self.horizon * self.repeat_length)) * (self.value_func(
                 ptu.from_numpy(final_obs), **self.value_func_kwargs))
     return returns
Example 15
    def get_action(self, state):
        if (self._steps_since_last_sample >= self.steps_between_sampling
                or self._last_latent is None) and not self.fixed_latent:
            latent = self.sample_latent(state)
            self._steps_since_last_sample = 0
        else:
            latent = self._last_latent
        self._steps_since_last_sample += 1

        state = ptu.from_numpy(state)
        sz = torch.cat((state, latent))
        action, *_ = self.policy.forward(sz)
        return ptu.get_numpy(action), dict()
Example 16
 def HVP(self, observations, actions, old_policy, vector, regu_coef=None):
     regu_coef = self.FIM_invert_args['damping'] if regu_coef is None else regu_coef
     vec = ptu.from_numpy(vector).float()
     if self.hvp_sample_frac is not None and self.hvp_sample_frac < 0.99:
         num_samples = observations.shape[0]
         rand_idx = np.random.choice(num_samples, size=int(self.hvp_sample_frac*num_samples))
         obs = observations[rand_idx]
         act = actions[rand_idx]
     else:
         obs = observations
         act = actions
     log_probs = torch.squeeze(self.policy.get_log_probs(obs, act), dim=-1)
     log_probs_old = torch.squeeze(old_policy.get_log_probs(obs, act), dim=-1)
     mean_kl = (log_probs_old - log_probs).mean()
     grad_fo = torch.autograd.grad(mean_kl, self.policy.trainable_params, create_graph=True)
     flat_grad = torch.cat([g.contiguous().view(-1) for g in grad_fo])
     h = torch.sum(flat_grad*vec)
     hvp = torch.autograd.grad(h, self.policy.trainable_params)
     hvp_flat = np.concatenate([g.contiguous().view(-1).cpu().data.numpy() for g in hvp])
     return hvp_flat + regu_coef * vector
Example 17
 def _log_prob_from_pre_tanh(self, pre_tanh_value):
     """
     Adapted from
     https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L73
     This formula is mathematically equivalent to log(1 - tanh(x)^2).
     Derivation:
     log(1 - tanh(x)^2)
      = log(sech(x)^2)
      = 2 * log(sech(x))
      = 2 * log(2e^-x / (e^-2x + 1))
      = 2 * (log(2) - x - log(e^-2x + 1))
      = 2 * (log(2) - x - softplus(-2x))
     :param pre_tanh_value: arctanh(x), where x is the squashed value
     :return: log probability of x, including the tanh Jacobian correction
     (A numerical check of the identity above follows this example.)
     """
     log_prob = self.normal.log_prob(pre_tanh_value)
     correction = -2. * (ptu.from_numpy(np.log([2.])) - pre_tanh_value -
                         torch.nn.functional.softplus(-2. * pre_tanh_value))
     return log_prob + correction
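A standalone numerical check of the identity derived in the docstring (not part of the original code):

import math
import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, steps=11, dtype=torch.float64)
lhs = torch.log(1 - torch.tanh(x) ** 2)
rhs = 2. * (math.log(2.) - x - F.softplus(-2. * x))
assert torch.allclose(lhs, rhs, atol=1e-6)   # identity holds to numerical precision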
Example 18
def create_mask(inverse_beta_func, n_quantiles, risk_kwargs):
    """
    x in [0, 1] represents the CDF of the input.
    beta(x) represents the cumulative weight assigned to the lower x% of
        values, e.g. it is analogous to the CDF. This is typically easier
        to represent via the inverse of the beta function, so we take the
        inverse of the inverse beta function to get the original function.
    The reweighted function becomes:
        R(f, beta) = sum_i f(i/n) * (beta((i+1)/n) - beta(i/n))
    (A concrete numeric example follows this function.)
    """

    tau = np.linspace(0, 1, n_quantiles + 1)
    betas = np.zeros(n_quantiles + 1)
    mask = np.zeros(n_quantiles)

    # TODO: there are some issues with mask and risk_kwarg caching

    for i in range(n_quantiles + 1):
        betas[i] = inverse_beta_func(tau[i], risk_kwargs)
    for i in range(n_quantiles):
        mask[i] = betas[i + 1] - betas[i]

    return ptu.from_numpy(mask)
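As a concrete numeric example of the scheme in the docstring, the sketch below builds the quantile mask for a CVaR-style beta with plain numpy; cvar_beta and alpha are assumed choices for illustration, not necessarily what inverse_beta_func and risk_kwargs configure:

import numpy as np

def cvar_beta(x, alpha=0.25):
    # cumulative weight of the lower-x fraction under CVaR_alpha:
    # all weight is placed on the lowest alpha fraction of outcomes
    return np.minimum(x / alpha, 1.0)

n_quantiles = 8
tau = np.linspace(0, 1, n_quantiles + 1)
mask = cvar_beta(tau[1:]) - cvar_beta(tau[:-1])   # beta((i+1)/n) - beta(i/n)
print(mask)        # [0.5, 0.5, 0., 0., 0., 0., 0., 0.]: only the lowest quantiles are weighted
print(mask.sum())  # 1.0, so the weights form a proper distribution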
Example 19
    def train_from_paths(self, paths):
        """
        Path preprocessing; we copy the paths so we don't modify them when they are used elsewhere
        """

        paths = copy.deepcopy(paths)
        for path in paths:
            # Other places like to have an extra dimension so that all arrays are 2D
            path['terminals'] = np.squeeze(path['terminals'], axis=-1)
            path['rewards'] = np.squeeze(path['rewards'], axis=-1)

            # Reward normalization; divide by std of reward in replay buffer
            path['rewards'] = np.clip(
                path['rewards'] / (self._reward_std + 1e-3), -10, 10)

        obs, actions = [], []
        for path in paths:
            obs.append(path['observations'])
            actions.append(path['actions'])
        obs = np.concatenate(obs, axis=0)
        actions = np.concatenate(actions, axis=0)

        obs_tensor, act_tensor = ptu.from_numpy(obs), ptu.from_numpy(actions)
        """
        Policy training loop
        """

        old_policy = copy.deepcopy(self.policy)
        with torch.no_grad():
            log_probs_old = old_policy.get_log_probs(
                obs_tensor, act_tensor).squeeze(dim=-1)

        rem_value_epochs = self.num_epochs
        for epoch in range(self.num_policy_epochs):
            """
            Recompute advantages at the beginning of each epoch. This allows for advantages
                to utilize the latest value function.
            Note: while this is not present in most implementations, it is recommended
                  by Andrychowicz et al. 2020.
            """

            path_functions.calculate_baselines(paths, self.value_func)
            path_functions.calculate_returns(paths, self.discount)
            path_functions.calculate_advantages(
                paths,
                self.discount,
                self.gae_lambda,
                self.normalize_advantages,
            )

            advantages, returns, baselines = [], [], []
            for path in paths:
                advantages = np.append(advantages, path['advantages'])
                returns = np.append(returns, path['returns'])

            if epoch == 0 and self._need_to_update_eval_statistics:
                with torch.no_grad():
                    values = torch.squeeze(self.value_func(obs_tensor), dim=-1)
                    values_np = ptu.get_numpy(values)
                first_val_loss = ((returns - values_np)**2).mean()

            old_params = self.policy.get_param_values()

            num_policy_steps = len(advantages) // self.policy_batch_size
            for _ in range(num_policy_steps):
                if num_policy_steps == 1:
                    batch = dict(
                        observations=obs,
                        actions=actions,
                        advantages=advantages,
                    )
                else:
                    batch = ppp.sample_batch(
                        self.policy_batch_size,
                        observations=obs,
                        actions=actions,
                        advantages=advantages,
                    )
                policy_loss, kl = self.train_policy(batch, old_policy)

            with torch.no_grad():
                log_probs = self.policy.get_log_probs(
                    obs_tensor, act_tensor).squeeze(dim=-1)
            kl = (log_probs_old - log_probs).mean()

            if (self.target_kl is not None
                    and kl > 1.5 * self.target_kl) or (kl != kl):
                if epoch > 0 or kl != kl:  # nan check
                    self.policy.set_param_values(old_params)
                break

            num_value_steps = len(advantages) // self.value_batch_size
            for i in range(num_value_steps):
                batch = ppp.sample_batch(
                    self.value_batch_size,
                    observations=obs,
                    targets=returns,
                )
                value_loss = self.train_value(batch)
            rem_value_epochs -= 1

        # Ensure the value function is always updated for the maximum number
        # of epochs, regardless of if the policy wants to terminate early.
        for _ in range(rem_value_epochs):
            num_value_steps = len(advantages) // self.value_batch_size
            for i in range(num_value_steps):
                batch = ppp.sample_batch(
                    self.value_batch_size,
                    observations=obs,
                    targets=returns,
                )
                value_loss = self.train_value(batch)

        if self._need_to_update_eval_statistics:
            with torch.no_grad():
                _, _, _, log_pi, *_ = self.policy(obs_tensor,
                                                  return_log_prob=True)
                values = torch.squeeze(self.value_func(obs_tensor), dim=-1)
                values_np = ptu.get_numpy(values)

            errors = returns - values_np
            explained_variance = 1 - (np.var(errors) / np.var(returns))
            value_loss = errors**2

            self.eval_statistics['Num Epochs'] = epoch + 1

            self.eval_statistics['Policy Loss'] = ptu.get_numpy(
                policy_loss).mean()
            self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl).mean()
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantages',
                    advantages,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Returns',
                    returns,
                ))

            self.eval_statistics['Value Loss'] = value_loss.mean()
            self.eval_statistics['First Value Loss'] = first_val_loss
            self.eval_statistics[
                'Value Explained Variance'] = explained_variance
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Values',
                    ptu.get_numpy(values),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Value Squared Errors',
                    value_loss,
                ))
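path_functions.calculate_advantages is not shown in these examples; the sketch below is an assumption about what it computes, a standard GAE recursion (Schulman et al., 2016) consistent with how baselines, returns, and advantages are used above:

import numpy as np

def calculate_advantages(path, discount, gae_lambda, normalize=True):
    rewards = path['rewards']                     # shape (T,)
    baselines = path['baselines']                 # shape (T+1,); last entry is the bootstrap value
    deltas = rewards + discount * baselines[1:] - baselines[:-1]
    advantages = np.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):       # backward GAE recursion
        running = deltas[t] + discount * gae_lambda * running
        advantages[t] = running
    if normalize:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    path['advantages'] = advantages
    return advantages

# toy usage
path = dict(rewards=np.ones(5), baselines=np.zeros(6))
print(calculate_advantages(path, discount=0.99, gae_lambda=0.95, normalize=False))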
Example 20
    def train_from_torch(self, batch):
        # We only use the original batch to get the batch size for policy training
        """
        Generate synthetic data using dynamics model
        """
        if self._n_train_steps_total % self.rollout_generation_freq == 0:
            rollout_len = self.rollout_len_func(self._n_train_steps_total)
            total_samples = self.rollout_generation_freq * self.num_model_rollouts

            num_samples, generated_rewards, terminated = 0, np.array([]), []
            while num_samples < total_samples:
                batch_samples = min(self.rollout_batch_size,
                                    total_samples - num_samples)
                real_batch = self.replay_buffer.random_batch(batch_samples)
                start_states = real_batch['observations']

                with torch.no_grad():
                    paths = self.sample_paths(start_states, rollout_len)

                for path in paths:
                    self.generated_data_buffer.add_path(path)
                    num_samples += len(path['observations'])
                    generated_rewards = np.concatenate(
                        [generated_rewards, path['rewards'][:, 0]])
                    terminated.append(path['terminals'][-1, 0] > 0.5)

                if num_samples >= total_samples:
                    break

            gt.stamp('generating rollouts', unique=False)
        """
        Update policy on a mix of real and model-generated data (see the mixing sketch after this example)
        """

        batch_size = batch['observations'].shape[0]
        n_real_data = int(self.real_data_pct * batch_size)
        n_generated_data = batch_size - n_real_data

        for _ in range(self.num_policy_updates):
            batch = self.replay_buffer.random_batch(n_real_data)
            generated_batch = self.generated_data_buffer.random_batch(
                n_generated_data)

            for k in ('rewards', 'terminals', 'observations', 'actions',
                      'next_observations'):
                batch[k] = np.concatenate((batch[k], generated_batch[k]),
                                          axis=0)
                batch[k] = ptu.from_numpy(batch[k])

            self.policy_trainer.train_from_torch(batch)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics and self._n_train_steps_total % self.rollout_generation_freq == 0:
            self._need_to_update_eval_statistics = False

            self.eval_statistics['MBPO Rollout Length'] = rollout_len
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'MBPO Reward Predictions',
                    generated_rewards,
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'MBPO Rollout Terminations',
                    np.array(terminated).astype(float),
                ))

        self._n_train_steps_total += 1
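The policy update above mixes real and model-generated transitions according to real_data_pct. A minimal standalone sketch of that mixing with toy buffers stored as plain dicts of arrays (mixed_batch is a hypothetical helper, not part of the original code):

import numpy as np

def mixed_batch(real_buffer, generated_buffer, batch_size, real_data_pct):
    n_real = int(real_data_pct * batch_size)      # samples drawn from the environment buffer
    n_gen = batch_size - n_real                   # samples drawn from model rollouts
    batch = {}
    for k in real_buffer:
        real_idx = np.random.randint(len(real_buffer[k]), size=n_real)
        gen_idx = np.random.randint(len(generated_buffer[k]), size=n_gen)
        batch[k] = np.concatenate(
            (real_buffer[k][real_idx], generated_buffer[k][gen_idx]), axis=0)
    return batch

real = dict(observations=np.zeros((100, 3)), rewards=np.zeros((100, 1)))
generated = dict(observations=np.ones((500, 3)), rewards=np.ones((500, 1)))
batch = mixed_batch(real, generated, batch_size=8, real_data_pct=0.05)
# with a small real_data_pct, most of each policy batch is model-generated data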
Example 21
    def train_from_torch(self, batch):

        self._train_calls += 1
        if self._train_calls % self.train_every > 0:
            return

        rollout_len = self.rollout_len_func(self._n_train_steps_total)
        num_model_rollouts = max(self.num_model_samples // rollout_len, 1)
        self.eval_statistics['Rollout Length'] = rollout_len

        real_batch = self.replay_buffer.random_batch(num_model_rollouts)
        start_states = real_batch['observations']
        latents = self.generate_latents(start_states)

        observations = np.zeros((self.num_model_samples, self.obs_dim))
        next_observations = np.zeros((self.num_model_samples, self.obs_dim))
        actions = np.zeros((self.num_model_samples, self.action_dim))
        unfolded_latents = np.zeros((self.num_model_samples, self.latent_dim))
        disagreements = np.zeros(self.num_model_samples)

        num_samples, b_ind, num_traj = 0, 0, 0
        while num_samples < self.num_model_samples:
            e_ind = b_ind + 4192 // rollout_len
            with torch.no_grad():
                paths, path_disagreements = self.generate_paths(
                    dynamics_model=self.dynamics_model,
                    control_policy=self.control_policy,
                    start_states=start_states[b_ind:e_ind],
                    latents=ptu.from_numpy(latents[b_ind:e_ind]),
                    rollout_len=rollout_len,
                )

            b_ind = e_ind

            path_disagreements = ptu.get_numpy(path_disagreements)
            for i, path in enumerate(paths):
                clipped_len = min(
                    len(path['observations']) - (self.empowerment_horizon - 1),
                    self.num_model_samples - num_samples)
                bi, ei = num_samples, num_samples + clipped_len

                if self.empowerment_horizon > 1:
                    path['observations'] = path['observations'][:-(
                        self.empowerment_horizon - 1)]
                    path['next_observations'] = path['next_observations'][(
                        self.empowerment_horizon -
                        1):(self.empowerment_horizon - 1) + clipped_len]
                    path['actions'] = path['actions'][:-(
                        self.empowerment_horizon - 1)]

                observations[bi:ei] = path['observations'][:clipped_len]
                next_observations[bi:ei] = path[
                    'next_observations'][:clipped_len]
                actions[bi:ei] = path['actions'][:clipped_len]
                unfolded_latents[bi:ei] = latents[num_traj:num_traj + 1]
                disagreements[bi:ei] = path_disagreements[i, :clipped_len]

                num_samples += clipped_len
                num_traj += 1

                if num_samples >= self.num_model_samples:
                    break

        gt.stamp('generating rollouts', unique=False)

        if not self.relabel_rewards:
            rewards, (
                logp, logp_altz,
                denom), reward_diagnostics = self.calculate_intrinsic_rewards(
                    observations, next_observations, unfolded_latents)
            orig_rewards = rewards.copy()
            rewards, postproc_dict = self.reward_postprocessing(
                rewards, reward_kwargs=dict(disagreements=disagreements))
            reward_diagnostics.update(postproc_dict)

            if self._need_to_update_eval_statistics:
                self.eval_statistics.update(reward_diagnostics)

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Discriminator Log Pis',
                        logp,
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Discriminator Alt Log Pis',
                        logp_altz,
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Reward Denominator',
                        denom,
                    ))

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Rewards (Original)',
                        orig_rewards,
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Intrinsic Rewards (Processed)',
                        rewards,
                    ))

            gt.stamp('intrinsic reward calculation', unique=False)

        if self._need_to_update_eval_statistics:
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Latents',
                    latents,
                ))

        for t in range(self.num_model_samples):
            self.add_sample(
                observations[t],
                next_observations[t],
                next_observations[t],  # fix this
                actions[t],
                unfolded_latents[t],
                disagreement=disagreements[t],
            )

        gt.stamp('policy training', unique=False)

        self.train_discriminator(observations, next_observations,
                                 unfolded_latents)

        reward_kwargs = dict(
            disagreements=self._model_disagreements[:self._cur_replay_size])
        self.train_from_buffer(reward_kwargs=reward_kwargs)
Example 22
def calculate_contrastive_empowerment(
    discriminator,
    obs,
    next_obs,
    latents,
    num_prior_samples=512,
    distribution_type='uniform',
    split_group=4096 * 32,
    obs_mean=None,
    obs_std=None,
    return_diagnostics=False,
    prior=None,
):
    """
    Described in Sharma et al. (2019).
    Approximates the variational lower bound using an estimate of s' from (s, z).
    Uses contrastive negatives to approximate the denominator.
    (A numeric sketch of the final reward formula follows this function.)
    """

    discriminator.eval()

    if obs_mean is not None:
        obs = (obs - obs_mean) / (obs_std + 1e-6)
        next_obs = (next_obs - obs_mean) / (obs_std + 1e-6)

    obs_deltas = ptu.from_numpy(next_obs - obs)
    obs_altz = np.concatenate([obs] * num_prior_samples, axis=0)

    with torch.no_grad():
        logp = discriminator.get_log_prob(
            ptu.from_numpy(obs),
            ptu.from_numpy(latents),
            obs_deltas,
        )
        logp = ptu.get_numpy(logp)

    if distribution_type == 'uniform':
        latent_altz = np.random.uniform(low=-1,
                                        high=1,
                                        size=(obs_altz.shape[0],
                                              latents.shape[1]))
    elif distribution_type == 'prior':
        if prior is None:
            raise AssertionError('prior specified but not passed in')
        obs_t = ptu.from_numpy(obs_altz)
        latent_altz, *_ = prior.get_action(obs_t, deterministic=False)
    else:
        raise NotImplementedError('distribution_type not found')

    # keep track of next obs/delta
    next_obs_altz = np.concatenate([next_obs - obs] * num_prior_samples,
                                   axis=0)

    with torch.no_grad():
        if obs_altz.shape[0] <= split_group:
            logp_altz = ptu.get_numpy(
                discriminator.get_log_prob(
                    ptu.from_numpy(obs_altz),
                    ptu.from_numpy(latent_altz),
                    ptu.from_numpy(next_obs_altz),
                ))
        else:
            logp_altz = []
            for split_idx in range(obs_altz.shape[0] // split_group):
                start_split = split_idx * split_group
                end_split = (split_idx + 1) * split_group
                logp_altz.append(
                    ptu.get_numpy(
                        discriminator.get_log_prob(
                            ptu.from_numpy(obs_altz[start_split:end_split]),
                            ptu.from_numpy(latent_altz[start_split:end_split]),
                            ptu.from_numpy(
                                next_obs_altz[start_split:end_split]),
                        )))
            if obs_altz.shape[0] % split_group:
                start_split = obs_altz.shape[0] % split_group
                logp_altz.append(
                    ptu.get_numpy(
                        discriminator.get_log_prob(
                            ptu.from_numpy(obs_altz[-start_split:]),
                            ptu.from_numpy(latent_altz[-start_split:]),
                            ptu.from_numpy(next_obs_altz[-start_split:]),
                        )))
            logp_altz = np.concatenate(logp_altz)
    logp_altz = np.array(np.array_split(logp_altz, num_prior_samples))

    if return_diagnostics:
        diagnostics = dict()
        orig_rep = np.repeat(np.expand_dims(logp, axis=0),
                             axis=0,
                             repeats=num_prior_samples)
        diagnostics['Pct Random Skills > Original'] = (orig_rep <
                                                       logp_altz).mean()

    # final DADS reward
    intrinsic_reward = np.log(num_prior_samples + 1) - np.log(1 + np.exp(
        np.clip(logp_altz - logp.reshape(1, -1), -50, 50)).sum(axis=0))

    if not return_diagnostics:
        return intrinsic_reward, (logp, logp_altz, logp - intrinsic_reward)
    else:
        return intrinsic_reward, (logp, logp_altz,
                                  logp - intrinsic_reward), diagnostics
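The final lines above compute the contrastive DADS-style reward. A minimal numeric sketch of that formula with toy log-probabilities (K, N and the values are assumptions):

import numpy as np

K, N = 4, 3                              # prior samples, transitions
logp = np.array([-1.0, -2.0, -0.5])      # log q(s'|s, z) under the latent actually used
logp_altz = np.full((K, N), -3.0)        # log q(s'|s, z') for K latents resampled from the prior

reward = np.log(K + 1) - np.log(
    1 + np.exp(np.clip(logp_altz - logp.reshape(1, -1), -50, 50)).sum(axis=0))
# approximates log q(s'|s,z) minus the log of the average q(s'|s,z') over prior samples;
# it approaches log(K + 1) when z explains s' much better than random latents,
# and is near 0 when it is no better
print(reward)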
Example 23
    def train_from_buffer(self,
                          replay_buffer,
                          holdout_pct=0.2,
                          max_grad_steps=1000,
                          epochs_since_last_update=5):
        self._n_train_steps_total += 1
        if self._n_train_steps_total % self.train_call_freq > 0 and self._n_train_steps_total > 1:
            return

        data = replay_buffer.get_transitions()
        x = data[:, :self.obs_dim + self.action_dim]  # inputs  s, a
        y = data[:, self.obs_dim + self.action_dim:]  # predict r, d, ns
        y[:, -self.obs_dim:] -= x[:, :self.obs_dim]  # predict delta in the state

        # normalize network inputs
        self.ensemble.fit_input_stats(x)

        # generate holdout set
        inds = np.random.permutation(data.shape[0])
        x, y = x[inds], y[inds]

        n_train = max(int((1 - holdout_pct) * data.shape[0]),
                      data.shape[0] - 8092)
        n_test = data.shape[0] - n_train

        x_train, y_train = x[:n_train], y[:n_train]
        x_test, y_test = x[n_train:], y[n_train:]
        x_test, y_test = ptu.from_numpy(x_test), ptu.from_numpy(y_test)

        # train until holdout set convergence
        num_epochs, num_steps = 0, 0
        num_epochs_since_last_update = 0
        best_holdout_loss = float('inf')
        num_batches = int(np.ceil(n_train / self.batch_size))

        while num_epochs_since_last_update < epochs_since_last_update and num_steps < max_grad_steps:
            # generate idx for each model to bootstrap
            self.ensemble.train()
            for b in range(num_batches):
                b_idxs = np.random.randint(n_train,
                                           size=(self.ensemble_size *
                                                 self.batch_size))
                x_batch, y_batch = x_train[b_idxs], y_train[b_idxs]
                x_batch, y_batch = ptu.from_numpy(x_batch), ptu.from_numpy(
                    y_batch)
                x_batch = x_batch.view(self.ensemble_size, self.batch_size, -1)
                y_batch = y_batch.view(self.ensemble_size, self.batch_size, -1)
                loss = self.ensemble.get_loss(x_batch, y_batch)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            num_steps += num_batches

            # stop training based on holdout loss improvement
            self.ensemble.eval()
            with torch.no_grad():
                holdout_losses, holdout_errors = self.ensemble.get_loss(
                    x_test, y_test, split_by_model=True, return_l2_error=True)
            holdout_loss = sum(
                sorted(holdout_losses)[:self.num_elites]) / self.num_elites

            if num_epochs == 0 or \
               (best_holdout_loss - holdout_loss) / abs(best_holdout_loss) > 0.01:
                best_holdout_loss = holdout_loss
                num_epochs_since_last_update = 0
            else:
                num_epochs_since_last_update += 1

            num_epochs += 1

        self.ensemble.elites = np.argsort(holdout_losses)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False

            self.eval_statistics['Model Elites Holdout Loss'] = \
                np.mean(ptu.get_numpy(holdout_loss))
            self.eval_statistics['Model Holdout Loss'] = \
                np.mean(ptu.get_numpy(sum(holdout_losses))) / self.ensemble_size
            self.eval_statistics['Model Training Epochs'] = num_epochs
            self.eval_statistics['Model Training Steps'] = num_steps

            for i in range(self.ensemble_size):
                name = 'M%d' % (i + 1)
                self.eval_statistics[name + ' Loss'] = \
                    np.mean(ptu.get_numpy(holdout_losses[i]))
                self.eval_statistics[name + ' L2 Error'] = \
                    np.mean(ptu.get_numpy(holdout_errors[i]))
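The loop above implements holdout-based early stopping: training continues until the elite holdout loss has failed to improve on the best loss by more than 1% for epochs_since_last_update consecutive epochs (capped at max_grad_steps). A minimal standalone sketch of that stopping rule (stopping_epoch is a hypothetical helper applied to a toy loss sequence):

def stopping_epoch(holdout_losses, patience=5, rel_improvement=0.01):
    # returns the epoch index at which training would stop
    best, since_improvement = float('inf'), 0
    for epoch, loss in enumerate(holdout_losses):
        if epoch == 0 or (best - loss) / abs(best) > rel_improvement:
            best, since_improvement = loss, 0     # meaningful improvement: reset patience
        else:
            since_improvement += 1
        if since_improvement >= patience:
            return epoch
    return len(holdout_losses)

print(stopping_epoch([1.0, 0.8, 0.79, 0.789, 0.788, 0.787, 0.786, 0.785]))   # 7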