Example #1
    def estimate_objective(self, state, action, target=False):
        """
        Estimates the objective (state-value).

        Args:
            state (torch.Tensor): state of shape [batch_size, n_state_dims]
            action (torch.Tensor): action of shape [n_action_samples * batch_size, n_action_dims]
            target (bool): whether to use the target approx post

        Returns objective estimate of shape [n_action_samples * batch_size, 1]
        """
        pessimism = self.pessimism if self.mode == 'train' else -self.optimism
        approx_post = self.target_approx_post if target else self.approx_post
        kl = kl_divergence(approx_post,
                           self.prior,
                           n_samples=self.n_action_samples,
                           sample=action).sum(dim=1, keepdim=True)
        expanded_state = state.repeat(self.n_action_samples, 1)
        cond_log_like = self.q_value_estimator(self,
                                               expanded_state,
                                               action,
                                               detach_params=True,
                                               target=self.optimize_targets,
                                               pessimism=pessimism)
        objective = cond_log_like - self.alphas['pi'] * kl.repeat(
            self.n_action_samples, 1)
        if self.inf_target_kl and not target and self.mode == 'train':
            # KL from target approx. posterior
            self.target_approx_post.reset(
                self.target_approx_post._batch_size,
                dist_params={
                    'loc': self.target_approx_post.dist.loc.detach(),
                    'scale': self.target_approx_post.dist.scale.detach()
                })
            inf_kl = kl_divergence(approx_post,
                                   self.target_approx_post,
                                   n_samples=self.n_action_samples,
                                   sample=action).sum(dim=1, keepdim=True)
            objective = objective - self.alphas['target_inf'] * inf_kl.repeat(
                self.n_action_samples, 1)
        return objective
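A minimal standalone sketch of the KL-regularized objective assembled above, with torch.distributions.Normal standing in for the codebase's distribution wrappers; the shapes, the alpha value, and the random Q-value are illustrative assumptions, not the actual estimator.

# Sketch (assumed shapes, stand-in Q-value): objective = Q(s, a) - alpha * KL(q(a|s) || p(a|s)).
import torch
from torch.distributions import Normal, kl_divergence as torch_kl

batch_size, n_action_dims = 4, 2
approx_post = Normal(torch.zeros(batch_size, n_action_dims),
                     torch.ones(batch_size, n_action_dims))
prior = Normal(torch.zeros(batch_size, n_action_dims),
               2. * torch.ones(batch_size, n_action_dims))
action = approx_post.rsample()                        # [batch_size, n_action_dims]
q_value = torch.randn(batch_size, 1)                  # stand-in for the Q-value estimator
kl = torch_kl(approx_post, prior).sum(dim=1, keepdim=True)
alpha = 0.1
objective = q_value - alpha * kl                      # [batch_size, 1]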
Example #2
def estimate_agent_kl(env, agent, prev_episode):
    """
    Estimate the change in the agent's policy from the last collected episode.
    Estimated using D_KL (pi_old || pi_new), sampling from previous episode.

    Args:
        env (gym.Env): the environment (used to build the agent's distribution args)
        agent (Agent): the most recent version of the agent
        prev_episode (dict): the previously collected episode

    Returns the average KL divergence over the previous episode's timesteps.
    """
    if prev_episode is None:
        return 0.

    # create a distribution to hold old approx post params
    agent_args = get_agent_args(env)
    agent_args['approx_post_args']['n_input'] = None
    old_approx_post = Distribution(**agent_args['approx_post_args']).to(agent.device)

    agent.reset(); agent.eval()

    states = prev_episode['state']
    dist_params = prev_episode['distributions']['action']['approx_post']

    agent_kl = 0

    for timestep in range(prev_episode['state'].shape[0]):

        state = states[timestep:timestep+1]
        params = {k: v[timestep:timestep+1].to(agent.device) for k, v in dist_params.items()}

        old_approx_post.reset(dist_params=params)
        agent.act(state)
        kl = kl_divergence(old_approx_post, agent.approx_post).sum().detach().item()
        agent_kl += kl

    agent_kl /= prev_episode['state'].shape[0]

    return agent_kl
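A hedged sketch of the same measurement in isolation: the average KL(pi_old || pi_new) over stored per-timestep distribution parameters, with Normal distributions and random parameters standing in for the episode data and the agent.

# Sketch (assumed stored parameters): average KL(pi_old || pi_new) over timesteps.
import torch
from torch.distributions import Normal, kl_divergence as torch_kl

old_locs, old_scales = torch.zeros(10, 2), torch.ones(10, 2)         # stored approx. post. params
new_locs, new_scales = 0.1 * torch.randn(10, 2), torch.ones(10, 2)   # current policy on the same states

agent_kl = 0.
for t in range(old_locs.shape[0]):
    old_post = Normal(old_locs[t:t + 1], old_scales[t:t + 1])
    new_post = Normal(new_locs[t:t + 1], new_scales[t:t + 1])
    agent_kl += torch_kl(old_post, new_post).sum().item()
agent_kl /= old_locs.shape[0]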
Example #3
    def forward(self,
                agent,
                state,
                action,
                target=False,
                both=False,
                detach_params=False,
                direct=False,
                pessimism=1,
                *args,
                **kwargs):
        """
        Estimates the Q-value from the state and action using the model and the Q-networks.

        Args:
            agent (Agent): the agent used to roll out the policy in the model
            state (torch.Tensor): the state [batch_size * n_action_samples, state_dim]
            action (torch.Tensor): the action [batch_size * n_action_samples, action_dim]
            target (bool): whether to use the target networks
            both (bool): whether to return both values (or the min value)
            detach_params (bool): whether to use detached (copied) parameters
            direct (bool): whether to get the direct (network) estimate
            pessimism (float): value estimate uncertainty penalty

        Returns a Q-value estimate of shape [n_action_samples * batch_size, 1]
        """
        if direct:
            return self.direct_estimate(agent, state, action, target, both,
                                        detach_params, pessimism)

        if target:
            q_value_models = self.target_q_value_models
            q_value_variables = self.target_q_value_variables
        else:
            q_value_models = self.q_value_models
            q_value_variables = self.q_value_variables
        if detach_params:
            q_value_models = copy.deepcopy(q_value_models)
            q_value_variables = copy.deepcopy(q_value_variables)

        self.planning_mode(agent)
        # set the previous state for residual state prediction
        self.state_variable.cond_likelihood.set_prev_x(state)
        # roll out the model
        actions_list = [action]
        states_list = [state]
        rewards_list = []
        kl_list = []
        q_values_list = []
        for _ in range(self.horizon):
            # estimate the Q-value at current state
            action = action.tanh() if agent.postprocess_action else action
            q_value_input = [
                model(state=state, action=action) for model in q_value_models
            ]
            q_values = [
                variable(inp)
                for variable, inp in zip(q_value_variables, q_value_input)
            ]
            # q_value = torch.min(q_values[0], q_values[1])
            q_values = torch.cat(q_values, dim=1)
            q_mean = q_values.mean(dim=1, keepdim=True)
            q_std = (q_values.var(dim=1, keepdim=True) + 1e-6).sqrt()
            q_value = q_mean - pessimism * q_std

            q_values_list.append(q_value)
            # predict state and reward
            self.generate_state(state, action, detach_params)
            self.generate_reward(state, action, detach_params)
            reward = self.reward_variable.sample()
            rewards_list.append(reward)
            state = self.state_variable.sample()
            states_list.append(state)
            # generate the action
            agent.generate_prior(state, detach_params)
            if agent.prior_model is not None:
                # sample from the learned prior
                action = agent.prior.sample()
                kl_list.append(
                    torch.zeros(action.shape[0], 1, device=action.device))
            else:
                # estimate approximate posterior
                agent.inference(state, detach_params, direct=True)
                dist = agent.direct_approx_post if agent.direct_approx_post is not None else agent.approx_post
                action = dist.sample()
                # calculate KL divergence
                kl = kl_divergence(dist,
                                   agent.prior,
                                   n_samples=1,
                                   sample=action).sum(dim=1, keepdim=True)
                kl_list.append(agent.alphas['pi'] * kl)

        actions_list.append(action)
        # estimate Q-value at final state
        action = action.tanh() if agent.postprocess_action else action
        q_value_input = [
            model(state=state, action=action) for model in q_value_models
        ]
        q_values = [
            variable(inp)
            for variable, inp in zip(q_value_variables, q_value_input)
        ]
        # q_value = torch.min(q_values[0], q_values[1])
        q_values = torch.cat(q_values, dim=1)
        q_mean = q_values.mean(dim=1, keepdim=True)
        q_std = (q_values.var(dim=1, keepdim=True) + 1e-6).sqrt()
        q_value = q_mean - pessimism * q_std
        q_values_list.append(q_value)

        # calculate the Q-value estimate
        total_rewards = torch.stack(rewards_list)
        total_kl = torch.stack(kl_list)
        total_q_values = torch.stack(q_values_list)

        if self.value_estimate == 'n_step':
            estimate = n_step(total_q_values,
                              total_rewards,
                              total_kl,
                              discount=agent.reward_discount)
        elif self.value_estimate == 'average_n_step':
            estimate = average_n_step(total_q_values,
                                      total_rewards,
                                      total_kl,
                                      discount=agent.reward_discount)
        elif self.value_estimate == 'exp_average_n_step':
            estimate = exp_average_n_step(total_q_values,
                                          total_rewards,
                                          total_kl,
                                          discount=agent.reward_discount,
                                          factor=1.)
        elif self.value_estimate == 'retrace':
            estimate = retrace_n_step(total_q_values,
                                      total_rewards,
                                      total_kl,
                                      discount=agent.reward_discount,
                                      factor=agent.retrace_lambda)
        else:
            raise NotImplementedError

        self.acting_mode(agent)

        # self.rollout_states.append(states_list)
        # self.rollout_rewards.append(rewards_list)
        # self.rollout_q_values.append(q_values_list)
        # self.rollout_actions.append(actions_list)

        return estimate
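Two pieces of the roll-out above, isolated into a minimal sketch: the pessimistic ensemble aggregation (mean minus pessimism times the ensemble standard deviation) and a plain n-step estimate that discounts the KL-penalized rewards and bootstraps from the final Q-value. The tensors, horizon, and discount are assumptions for illustration.

# Sketch (assumed tensors): pessimistic ensemble aggregation and an n-step estimate.
import torch

batch_size, n_critics, horizon = 4, 2, 5
pessimism, discount = 1., 0.99

# ensemble aggregation: mean minus pessimism * ensemble std
q_ensemble = torch.randn(batch_size, n_critics)
q_mean = q_ensemble.mean(dim=1, keepdim=True)
q_std = (q_ensemble.var(dim=1, keepdim=True) + 1e-6).sqrt()
q_value = q_mean - pessimism * q_std                  # [batch_size, 1]

# n-step estimate: discounted (reward - KL) terms plus a bootstrapped final Q-value
rewards = torch.randn(horizon, batch_size, 1)
kls = torch.zeros(horizon, batch_size, 1)
final_q = torch.randn(batch_size, 1)
discounts = discount ** torch.arange(horizon, dtype=torch.float32).view(-1, 1, 1)
estimate = (discounts * (rewards - kls)).sum(dim=0) + discount ** horizon * final_q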
Example #4
    def _collect_kl_objectives(self, on_policy_action, target_on_policy_action,
                               valid, done):
        """
        Collect the KL divergence objectives to train the prior.

        Args:
            on_policy_action (torch.Tensor): action samples from the current approx. posterior
            target_on_policy_action (torch.Tensor): action samples from the target (or direct) approx. posterior
            valid (torch.Tensor): validity mask
            done (torch.Tensor): episode termination flags
        """
        if self.agent.target_prior_model is not None:
            batch_size = self.agent.prior._batch_size
            # get the distribution parameters
            target_prior_loc = self.agent.target_prior.dist.loc
            target_prior_scale = self.agent.target_prior.dist.scale
            current_prior_loc = self.agent.prior.dist.loc
            current_prior_scale = self.agent.prior.dist.scale
            if 'loc' in dir(self.agent.approx_post.dist):
                post_loc = self.agent.approx_post.dist.loc
                post_scale = self.agent.approx_post.dist.scale
                self.agent.approx_post.reset(batch_size,
                                             dist_params={
                                                 'loc': post_loc.detach(),
                                                 'scale': post_scale.detach()
                                             })

            # decoupled updates on the prior
            # loc KLs
            self.agent.prior.reset(batch_size,
                                   dist_params={
                                       'loc': current_prior_loc,
                                       'scale': target_prior_scale.detach()
                                   })
            kl_prev_loc = kl_divergence(
                self.agent.target_prior,
                self.agent.prior,
                n_samples=self.agent.n_action_samples).sum(dim=1, keepdim=True)
            kl_curr_loc = kl_divergence(self.agent.approx_post,
                                        self.agent.prior,
                                        n_samples=self.agent.n_action_samples,
                                        sample=on_policy_action).sum(
                                            dim=1, keepdim=True)
            # scale KLs
            self.agent.prior.reset(batch_size,
                                   dist_params={
                                       'loc': target_prior_loc.detach(),
                                       'scale': current_prior_scale
                                   })
            kl_prev_scale = kl_divergence(
                self.agent.target_prior,
                self.agent.prior,
                n_samples=self.agent.n_action_samples).sum(dim=1, keepdim=True)
            kl_curr_scale = kl_divergence(
                self.agent.approx_post,
                self.agent.prior,
                n_samples=self.agent.n_action_samples,
                sample=on_policy_action).sum(dim=1, keepdim=True)

            # append the KL objectives
            self.objectives['action_kl_prev_loc'].append(
                self.agent.alphas['loc'] * kl_prev_loc * (1 - done) * valid)
            self.objectives['action_kl_curr_loc'].append(
                self.agent.alphas['pi'] * kl_curr_loc * (1 - done) * valid)
            self.objectives['action_kl_prev_scale'].append(
                self.agent.alphas['scale'] * kl_prev_scale * (1 - done) *
                valid)
            self.objectives['action_kl_curr_scale'].append(
                self.agent.alphas['pi'] * kl_curr_scale * (1 - done) * valid)

            # report the KL divergences
            self.metrics['action']['kl_prev_loc'].append(
                (kl_prev_loc * (1 - done) * valid).detach())
            self.metrics['action']['kl_curr_loc'].append(
                (kl_curr_loc * (1 - done) * valid).detach())
            self.metrics['action']['kl_prev_scale'].append(
                (kl_prev_scale * (1 - done) * valid).detach())
            self.metrics['action']['kl_curr_scale'].append(
                (kl_curr_scale * (1 - done) * valid).detach())

            # report the prior and target prior distribution parameters
            self.metrics['action']['prior_prev_loc'].append(
                target_prior_loc.detach().mean(dim=1, keepdim=True))
            self.metrics['action']['prior_prev_scale'].append(
                target_prior_scale.detach().mean(dim=1, keepdim=True))
            self.metrics['action']['prior_curr_loc'].append(
                current_prior_loc.detach().mean(dim=1, keepdim=True))
            self.metrics['action']['prior_curr_scale'].append(
                current_prior_scale.detach().mean(dim=1, keepdim=True))

            # reset the prior with detached parameters and approx. post. with
            # non-detached parameters to evaluate KL for approx. post.
            self.agent.prior.reset(batch_size,
                                   dist_params={
                                       'loc': current_prior_loc.detach(),
                                       'scale': current_prior_scale.detach()
                                   })
            if 'loc' in dir(self.agent.approx_post.dist):
                self.agent.approx_post.reset(batch_size,
                                             dist_params={
                                                 'loc': post_loc,
                                                 'scale': post_scale
                                             })

        if 'loc' in dir(self.agent.approx_post.dist):
            # report the approx. post. distribution parameters
            current_post_loc = self.agent.approx_post.dist.loc
            current_post_scale = self.agent.approx_post.dist.scale
            self.metrics['action']['approx_post_loc'].append(
                current_post_loc.detach().mean(dim=1, keepdim=True))
            self.metrics['action']['approx_post_scale'].append(
                current_post_scale.detach().mean(dim=1, keepdim=True))

        # evaluate the KL for reporting (and possibly Q-network targets)
        kl = kl_divergence(self.agent.approx_post,
                           self.agent.prior,
                           n_samples=self.agent.n_action_samples,
                           sample=on_policy_action).sum(dim=1, keepdim=True)
        self.metrics['action']['kl'].append((kl * (1 - done) * valid).detach())

        if self.agent.direct_approx_post is not None or (
                self.agent.target_approx_post is not None
                and self.agent.target_inf_value_targets):
            # evaluate the KL for the direct approx. post.
            approx_post = self.agent.direct_approx_post if self.agent.direct_approx_post is not None else self.agent.target_approx_post
            kl = kl_divergence(approx_post,
                               self.agent.prior,
                               n_samples=self.agent.n_action_samples,
                               sample=target_on_policy_action).sum(
                                   dim=1, keepdim=True)
            self.metrics['action']['target_kl'].append(
                (kl * (1 - done) * valid).detach())

        if self.agent.inf_target_kl:
            # KL between approx post and target approx post
            self.agent.target_approx_post.reset(
                self.agent.target_approx_post._batch_size,
                dist_params={
                    'loc': self.agent.target_approx_post.dist.loc.detach(),
                    'scale': self.agent.target_approx_post.dist.scale.detach()
                })
            kl = kl_divergence(self.agent.approx_post,
                               self.agent.target_approx_post,
                               n_samples=self.agent.n_action_samples,
                               sample=on_policy_action).sum(dim=1,
                                                            keepdim=True)
            self.metrics['action']['kl_target_inf'].append(
                (kl * (1 - done) * valid).detach())
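A minimal sketch of the decoupled update used above: the KL to the target prior is evaluated twice, once with the scale taken from the (detached) target so only the loc receives gradient, and once with the loc taken from the target so only the scale receives gradient. Normal distributions and the parameter tensors are assumptions.

# Sketch (assumed parameters): decoupled loc / scale KL terms for training the prior.
import torch
from torch.distributions import Normal, kl_divergence as torch_kl

target_loc, target_scale = torch.zeros(4, 2), torch.ones(4, 2)   # target prior parameters
loc = torch.zeros(4, 2, requires_grad=True)                      # current prior parameters
log_scale = torch.zeros(4, 2, requires_grad=True)
scale = log_scale.exp()

target_prior = Normal(target_loc, target_scale)
# loc term: current loc, detached target scale -> gradient reaches loc only
kl_loc = torch_kl(target_prior, Normal(loc, target_scale.detach())).sum(dim=1, keepdim=True)
# scale term: detached target loc, current scale -> gradient reaches scale only
kl_scale = torch_kl(target_prior, Normal(target_loc.detach(), scale)).sum(dim=1, keepdim=True)
(kl_loc.mean() + kl_scale.mean()).backward()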
Example #5
    def _collect_inf_opt_objective(self, state, on_policy_action, valid, done):
        """
        Evaluates the inference optimizer objective if there are amortized parameters.

        Args:
            state (torch.Tensor): the state
            on_policy_action (torch.Tensor): action samples from the current approx. posterior
            valid (torch.Tensor): validity mask
            done (torch.Tensor): episode termination flags
        """
        if 'parameters' in dir(self.agent.inference_optimizer):
            # detach the prior
            if self.agent.prior_model is not None:
                batch_size = self.agent.prior._batch_size
                loc = self.agent.prior.dist.loc
                scale = self.agent.prior.dist.scale
                self.agent.prior.reset(batch_size,
                                       dist_params={
                                           'loc': loc.detach(),
                                           'scale': scale.detach()
                                       })
            # evaluate the objective
            obj = self.agent.estimate_objective(state, on_policy_action)
            obj = obj.view(self.agent.n_action_samples, -1, 1).mean(dim=0)
            if self.agent.inference_optimizer.n_inf_iters > 1:
                # append final objective, calculate inference improvement
                self.agent.inference_optimizer.estimated_objectives.append(
                    -obj.detach())
                objectives = torch.stack(
                    self.agent.inference_optimizer.estimated_objectives)
                inf_imp = objectives[0] - objectives[-1]
                self.inference_improvement.append(inf_imp)
            # note: multiply by batch size because we divide later (in optimizer)
            obj = -obj * valid * (1 - done) * self.agent.batch_size
            self.objectives['inf_opt_obj'].append(obj)
            # re-attach the prior
            if self.agent.prior_model is not None:
                self.agent.prior.reset(batch_size,
                                       dist_params={
                                           'loc': loc,
                                           'scale': scale
                                       })

        if self.agent.direct_approx_post is not None:
            # train the direct inference model using on policy actions
            # log_prob = self.agent.direct_approx_post.log_prob(on_policy_action).sum(dim=2)
            # log_prob = log_prob.view(self.agent.n_action_samples, -1, 1).mean(dim=0)
            # self.objectives['direct_inf_opt_obj'].append(-log_prob * valid * (1 - done))
            batch_size = self.agent.approx_post._batch_size
            loc = self.agent.approx_post.dist.loc
            scale = self.agent.approx_post.dist.scale
            self.agent.approx_post.reset(batch_size,
                                         dist_params={
                                             'loc': loc.detach(),
                                             'scale': scale.detach()
                                         })
            kl = kl_divergence(self.agent.approx_post,
                               self.agent.direct_approx_post,
                               n_samples=self.agent.n_action_samples,
                               sample=on_policy_action).sum(dim=1,
                                                            keepdim=True)
            self.objectives['direct_inf_opt_obj'].append(kl * valid *
                                                         (1 - done))
            self.agent.approx_post.reset(batch_size,
                                         dist_params={
                                             'loc': loc,
                                             'scale': scale
                                         })
            self.metrics['action']['direct_kl'].append(
                (kl * (1 - done) * valid).detach())
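A minimal sketch of the distillation term at the end of the method above: the direct (amortized) posterior is trained to match a detached copy of the optimized posterior via a KL divergence, so gradients reach only the direct model. The parameter tensors are assumptions.

# Sketch (assumed parameters): distill the optimized posterior into the direct model via KL.
import torch
from torch.distributions import Normal, kl_divergence as torch_kl

opt_loc, opt_scale = torch.randn(4, 2), torch.ones(4, 2)          # optimized approx. post. (fixed)
direct_loc = torch.zeros(4, 2, requires_grad=True)                # direct model outputs (trainable here)
direct_log_scale = torch.zeros(4, 2, requires_grad=True)

optimized_post = Normal(opt_loc.detach(), opt_scale.detach())
direct_post = Normal(direct_loc, direct_log_scale.exp())
kl = torch_kl(optimized_post, direct_post).sum(dim=1, keepdim=True)
kl.mean().backward()                                              # gradient reaches only the direct model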
Example #6
def compare_policies(exp_key1, exp_key2, write_result=True):
    """
    Compares the policies of two agents at the end of training.

    Args:
        exp_key1 (str): the Comet experiment key of the first agent
        exp_key2 (str): the Comet experiment key of the second agent
        write_result (bool): whether to pickle the resulting KL estimates

    Returns a dict with lists 'kl12' and 'kl21' of per-state KL estimates.
    """
    # load the experiments
    comet_api = comet_ml.API(api_key=LOADING_API_KEY)
    exp1 = comet_api.get_experiment(project_name=PROJECT_NAME,
                                    workspace=WORKSPACE,
                                    experiment=exp_key1)
    exp2 = comet_api.get_experiment(project_name=PROJECT_NAME,
                                    workspace=WORKSPACE,
                                    experiment=exp_key2)

    # create the environment
    param_summary = exp1.get_parameters_summary()
    env_name = [a for a in param_summary if a['name'] == 'env'][0]['valueCurrent']
    env1 = create_env(env_name)
    env2 = create_env(env_name)

    # create the agents
    asset_list = exp1.get_asset_list()
    agent_config_asset_list = [a for a in asset_list if 'agent_args' in a['fileName']]
    agent_args = None
    if len(agent_config_asset_list) > 0:
        # if we've saved the agent config dict, load it
        agent_args = exp1.get_asset(agent_config_asset_list[0]['assetId'])
        agent_args = json.loads(agent_args)
        agent_args = agent_args if 'opt_type' in agent_args['inference_optimizer_args'] else None
    agent1 = create_agent(env1, agent_args=agent_args)[0]
    load_checkpoint(agent1, exp_key1)

    asset_list = exp2.get_asset_list()
    agent_config_asset_list = [a for a in asset_list if 'agent_args' in a['fileName']]
    agent_args = None
    if len(agent_config_asset_list) > 0:
        # if we've saved the agent config dict, load it
        agent_args = exp2.get_asset(agent_config_asset_list[0]['assetId'])
        agent_args = json.loads(agent_args)
        agent_args = agent_args if 'opt_type' in agent_args['inference_optimizer_args'] else None
    agent2 = create_agent(env1, agent_args=agent_args)[0]
    load_checkpoint(agent2, exp_key2)

    # evaluate the KL between policies
    kl12 = []
    kl21 = []
    agent1.reset(); agent1.eval()
    agent2.reset(); agent2.eval()

    state1 = env1.reset()
    state2 = env2.reset()

    for state_ind in range(N_STATES):
        # perform policy optimization on state1
        action1 = agent1.act(state1)
        agent2.act(state1)
        kl = kl_divergence(agent1.approx_post, agent2.approx_post).sum().detach().item()
        kl12.append(kl)

        agent1.reset(); agent1.eval()
        agent2.reset(); agent2.eval()

        # perform policy optimization on state2
        agent1.act(state2)
        action2 = agent2.act(state2)
        kl = kl_divergence(agent2.approx_post, agent1.approx_post).sum().detach().item()
        kl21.append(kl)

        # step the environments
        state1, _, done1, _ = env1.step(action1)
        state2, _, done2, _ = env2.step(action2)

        if done1:
            agent1.reset(); agent1.eval()
            state1 = env1.reset()
            done1 = False
        if done2:
            agent2.reset(); agent2.eval()
            state2 = env2.reset()
            done2 = False

    kls = {'kl12': kl12,
           'kl21': kl21}

    if write_result:
        pickle.dump(kls, open('policy_kl_' + exp_key1 + '_' + exp_key2 + '.p', 'wb'))

    return kls
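A hedged usage sketch of the comparison utility; the experiment keys below are placeholders, and the Comet credentials (LOADING_API_KEY, PROJECT_NAME, WORKSPACE) are assumed to be configured as in the rest of the module.

# Hypothetical usage with placeholder experiment keys.
kls = compare_policies('EXP_KEY_1', 'EXP_KEY_2', write_result=False)
print('mean KL(pi1 || pi2):', sum(kls['kl12']) / len(kls['kl12']))
print('mean KL(pi2 || pi1):', sum(kls['kl21']) / len(kls['kl21']))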
Example #7
def estimate_monte_carlo_return(env, agent, env_state, state, action,
                                n_batches):
    """
    Estimates the discounted Monte Carlo return (including KL) for a policy from
    a state-action pair.

    Args:
        env (gym.Env): the environment
        agent (Agent): the agent
        env_state (tuple): the environment state from MuJoCo (qpos, qvel)
        state (torch.Tensor): the state (unused; roll-outs restart from env_state)
        action (torch.Tensor): the action of size [1, n_action_dims]
        n_batches (int): the number of batches of Monte Carlo roll-outs

    Returns numpy array of returns of size [n_batches * ROLLOUT_BATCH_SIZE].
    """
    total_samples = n_batches * ROLLOUT_BATCH_SIZE
    returns = np.zeros(total_samples)
    initial_action = action.repeat(ROLLOUT_BATCH_SIZE, 1).numpy()
    # create a synchronous environment to perform ROLLOUT_BATCH_SIZE roll-outs
    env = SynchronousEnv(env, ROLLOUT_BATCH_SIZE)

    for return_batch_num in range(n_batches):
        if return_batch_num % 1 == 0:
            print('     Batch ' + str(return_batch_num + 1) + ' of ' +
                  str(n_batches) + '.')
        agent.reset(batch_size=ROLLOUT_BATCH_SIZE)
        agent.eval()
        # set the environment
        env.reset()
        qpos, qvel = env_state
        env.set_state(qpos=qpos, qvel=qvel)
        state, reward, done, _ = env.step(initial_action)
        # rollout the environment, get return
        rewards = [reward.view(-1).numpy()]
        kls = [np.zeros(ROLLOUT_BATCH_SIZE)]
        n_steps = 1
        while not done.prod():
            if n_steps > 1000:
                break
            action = agent.act(state, reward, done)
            state, reward, done, _ = env.step(action)
            rewards.append(((1 - done) * reward).view(-1).numpy())
            kl = kl_divergence(agent.approx_post,
                               agent.prior,
                               n_samples=agent.n_action_samples).sum(
                                   dim=1, keepdim=True)
            kls.append(((1 - done) * kl.detach().cpu()).view(-1).numpy())
            n_steps += 1
        rewards = np.stack(rewards)
        kls = np.stack(kls)
        discounts = np.cumprod(agent.reward_discount * np.ones(kls.shape),
                               axis=0)
        discounts = np.concatenate(
            [np.ones((1, ROLLOUT_BATCH_SIZE)),
             discounts])[:-1].reshape(-1, ROLLOUT_BATCH_SIZE)
        rewards = discounts * (rewards -
                               agent.alphas['pi'].cpu().numpy() * kls)
        sample_returns = np.sum(rewards, axis=0)
        sample_ind = return_batch_num * ROLLOUT_BATCH_SIZE
        returns[sample_ind:sample_ind + ROLLOUT_BATCH_SIZE] = sample_returns
    return returns
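A minimal sketch of the return computation at the end of the roll-out above, isolating the discount alignment (no discount on the first step) and the KL penalty from the environment interaction; the reward and KL arrays are assumptions.

# Sketch (assumed arrays): discounted return with a KL penalty; the first step is undiscounted.
import numpy as np

n_steps, batch = 6, 3
discount, alpha = 0.99, 0.1
rewards = np.random.randn(n_steps, batch)
kls = np.abs(np.random.randn(n_steps, batch))

discounts = np.cumprod(discount * np.ones((n_steps, batch)), axis=0)
discounts = np.concatenate([np.ones((1, batch)), discounts])[:-1]   # [1, g, g^2, ...]
returns = np.sum(discounts * (rewards - alpha * kls), axis=0)        # [batch]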