Example #1
0
 def postprocess_trajectory(self,
                            sample_batch,
                            other_agent_batches=None,
                            episode=None):
     """Attach the "advantages" column to ``sample_batch``.

     Advantages are computed with a bootstrap value of 0.0, the configured
     ``gamma`` and GAE disabled.
     """
     discount = self.config["gamma"]
     return compute_advantages(sample_batch, 0.0, discount, use_gae=False)
 def postprocess_trajectory(self,
                            sample_batch,
                            other_agent_batches=None,
                            episode=None):
     """Compute advantages for one trajectory using ``gamma`` and ``lambda``.

     If the last step is done the bootstrap value is 0.0; otherwise it is
     ``self._value`` evaluated on the final "new_obs".
     """
     if sample_batch["dones"][-1]:
         bootstrap_value = 0.0
     else:
         final_obs = sample_batch["new_obs"][-1]
         bootstrap_value = self._value(final_obs)
     return compute_advantages(sample_batch, bootstrap_value,
                               self.config["gamma"], self.config["lambda"])
Example #3
0
 def postprocess_trajectory(self,
                            sample_batch,
                            other_agent_batches=None,
                            episode=None):
     """Compute advantages, bootstrapping truncated trajectories via the VF.

     When the last step is not done, ``self._value`` receives the final
     "new_obs" plus the final recurrent state outputs, each wrapped in a
     one-element list.
     """
     if sample_batch["dones"][-1]:
         last_r = 0.0
     else:
         rnn_state = [[sample_batch["state_out_{}".format(k)][-1]]
                      for k in range(len(self.model.state_in))]
         last_r = self._value(sample_batch["new_obs"][-1], *rnn_state)
     gamma, lambda_ = self.config["gamma"], self.config["lambda"]
     return compute_advantages(sample_batch, last_r, gamma, lambda_)
 def postprocess_trajectory(self,
                            sample_batch,
                            other_agent_batches=None,
                            episode=None):
     """Compute advantages (GAE off, bootstrap 0.0) for a finished episode.

     Raises:
         NotImplementedError: if the batch's last done flag is falsy, i.e.
             the batch was not produced with
             ``batch_mode='complete_episodes'``.
     """
     dones = sample_batch["dones"]
     if not dones[-1]:
         raise NotImplementedError(
             "last done mask in a batch should be True. "
             "For now, we only support reading experience batches produced "
             "with batch_mode='complete_episodes'.",
             len(dones), dones[-1])
     return compute_advantages(
         sample_batch, 0.0, gamma=self.config["gamma"], use_gae=False)
Example #5
0
def postprocess_ppo_gae(policy,
                        sample_batch,
                        other_agent_batches=None,
                        episode=None):
    """Compute (optionally GAE) advantages for a single trajectory.

    If the last step is done the bootstrap value is 0.0. Otherwise it is
    ``policy._value`` evaluated on the last next-observation, action and
    reward together with the last recurrent state outputs (each wrapped in a
    one-element list).
    """
    if sample_batch["dones"][-1]:
        last_r = 0.0
    else:
        state_outs = [[sample_batch["state_out_{}".format(k)][-1]]
                      for k in range(len(policy.state_in))]
        last_r = policy._value(
            sample_batch[SampleBatch.NEXT_OBS][-1],
            sample_batch[SampleBatch.ACTIONS][-1],
            sample_batch[SampleBatch.REWARDS][-1],
            *state_outs)
    config = policy.config
    return compute_advantages(sample_batch,
                              last_r,
                              config["gamma"],
                              config["lambda"],
                              use_gae=config["use_gae"])
Example #6
0
 def postprocess_trajectory(self,
                            sample_batch,
                            other_agent_batches=None,
                            episode=None):
     """Add advantages unless V-trace is enabled; always drop "new_obs".

     With ``vtrace`` off, advantages are computed with a bootstrap value of
     0.0 for done trajectories, or ``self.value`` on the final "new_obs" and
     final recurrent state otherwise. With ``vtrace`` on, the batch is passed
     through unchanged apart from removing "new_obs".
     """
     batch = sample_batch
     if not self.config["vtrace"]:
         if sample_batch["dones"][-1]:
             last_r = 0.0
         else:
             rnn_state = [[sample_batch["state_out_{}".format(k)][-1]]
                          for k in range(len(self.model.state_in))]
             last_r = self.value(sample_batch["new_obs"][-1], *rnn_state)
         batch = compute_advantages(batch,
                                    last_r,
                                    self.config["gamma"],
                                    self.config["lambda"],
                                    use_gae=self.config["use_gae"])
     # "new_obs" is not used afterwards; removing it saves bandwidth.
     del batch.data["new_obs"]
     return batch
Example #7
0
def centralized_critic_postprocessing(policy,
                                      sample_batch,
                                      other_agent_batches=None,
                                      episode=None):
    """Record opponent data, use the central VF, then compute advantages.

    Once the policy's loss is initialized, the single other agent's
    observations and actions are copied into this batch and VF_PREDS is
    overwritten by ``policy.compute_central_vf``. Before that, zero arrays of
    matching shapes are stored instead. Advantages bootstrap from 0.0 when
    the last step is done, otherwise from the last VF prediction.
    """
    if not policy.loss_initialized():
        # Loss not built yet: placeholder zeros with the right shapes.
        sample_batch[OPPONENT_OBS] = np.zeros_like(
            sample_batch[SampleBatch.CUR_OBS])
        sample_batch[OPPONENT_ACTION] = np.zeros_like(
            sample_batch[SampleBatch.ACTIONS])
        sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
            sample_batch[SampleBatch.REWARDS], dtype=np.float32)
    else:
        assert other_agent_batches is not None
        # Exactly one other agent is expected; the unpacking fails otherwise.
        [(_, opponent_batch)] = list(other_agent_batches.values())

        sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS]
        sample_batch[OPPONENT_ACTION] = opponent_batch[SampleBatch.ACTIONS]

        # The central VF replaces the default VF predictions.
        sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
            sample_batch[SampleBatch.CUR_OBS],
            sample_batch[OPPONENT_OBS],
            sample_batch[OPPONENT_ACTION])

    last_r = (0.0 if sample_batch["dones"][-1]
              else sample_batch[SampleBatch.VF_PREDS][-1])

    config = policy.config
    return compute_advantages(sample_batch,
                              last_r,
                              config["gamma"],
                              config["lambda"],
                              use_gae=config["use_gae"])
Example #8
0
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Recompute VF predictions and advantages with ``self.module.critic``.

        The parent class's postprocessing runs first. Then:
        - the bootstrap value is the critic's output on the final NEXT_OBS;
        - VF_PREDS is overwritten with the critic's outputs on CUR_OBS;
        - advantages are recomputed via ``compute_advantages``.
        """
        sample_batch = super().postprocess_trajectory(
            sample_batch, other_agent_batches=other_agent_batches,
            episode=episode)

        # NOTE(review): the bootstrap value is taken from the critic without
        # consulting the trajectory's done flag; confirm this is intended for
        # trajectories that end in a terminal state.
        final_next_obs = self.convert_to_tensor(
            sample_batch[SampleBatch.NEXT_OBS][-1])
        bootstrap_value = self.module.critic(final_next_obs).squeeze(-1).numpy()

        obs_tensor = self.convert_to_tensor(sample_batch[SampleBatch.CUR_OBS])
        vf_preds = self.module.critic(obs_tensor).squeeze(-1)
        sample_batch[SampleBatch.VF_PREDS] = vf_preds.numpy()

        config = self.config
        return compute_advantages(
            sample_batch,
            bootstrap_value,
            gamma=config["gamma"],
            lambda_=config["lambda"],
            use_gae=config["use_gae"],
        )
Example #9
0
def original_postprocess(policy,
                         sample_batch,
                         other_agent_batches=None,
                         episode=None):
    """Compute advantages unless V-trace is enabled.

    With ``vtrace`` enabled, the batch is returned unchanged. Otherwise the
    bootstrap value is 0.0 for done trajectories, or ``policy._value`` on the
    last next-observation, action, reward and recurrent state outputs.
    """
    if policy.config["vtrace"]:
        return sample_batch

    if sample_batch["dones"][-1]:
        last_r = 0.0
    else:
        state_outs = [[sample_batch["state_out_{}".format(k)][-1]]
                      for k in range(policy.num_state_tensors())]
        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
                               sample_batch[SampleBatch.ACTIONS][-1],
                               sample_batch[SampleBatch.REWARDS][-1],
                               *state_outs)
    config = policy.config
    return compute_advantages(sample_batch,
                              last_r,
                              config["gamma"],
                              config["lambda"],
                              use_gae=config["use_gae"])
Example #10
0
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
        episode: Optional["Episode"] = None,
    ):
        """Add MARWIL "advantages", i.e. discounted cumulative rewards.

        Runs the parent postprocessing first. A done trajectory bootstraps
        from 0.0. A truncated one bootstraps from ``self._value`` applied to
        a single-timestep input dict built from the model's view
        requirements. Advantages use neither GAE nor the critic.
        """
        sample_batch = super().postprocess_trajectory(
            sample_batch, other_agent_batches, episode
        )

        if not sample_batch[SampleBatch.DONES][-1]:
            # Truncated: estimate the remaining return from the last step.
            # Use the "last" index when NEXT_OBS is available, else -1.
            step_index = -1
            if SampleBatch.NEXT_OBS in sample_batch:
                step_index = "last"
            last_step_inputs = sample_batch.get_single_step_input_dict(
                self.model.view_requirements, index=step_index
            )
            last_r = self._value(**last_step_inputs)
        else:
            # Trajectory really ended: nothing left to bootstrap.
            last_r = 0.0

        # Only discounted cumulative rewards are wanted, hence no GAE and no
        # critic (use_critic=True would subtract vf-estimates from returns).
        return compute_advantages(
            sample_batch,
            last_r,
            self.config["gamma"],
            use_gae=False,
            use_critic=False,
        )
Example #11
0
 def postprocess_trajectory(self,
                            sample_batch,
                            other_agent_batches=None,
                            episode=None):
     """Compute advantages when V-trace is off, then drop "new_obs".

     The bootstrap value is 0.0 for done trajectories, otherwise
     ``self.value`` on the final "new_obs" and final recurrent state.
     """
     use_vtrace = self.config["vtrace"]
     if use_vtrace:
         batch = sample_batch
     else:
         if sample_batch["dones"][-1]:
             last_r = 0.0
         else:
             final_state = [[sample_batch["state_out_{}".format(k)][-1]]
                            for k in range(len(self.model.state_in))]
             last_r = self.value(sample_batch["new_obs"][-1], *final_state)
         batch = compute_advantages(
             sample_batch,
             last_r,
             self.config["gamma"],
             self.config["lambda"],
             use_gae=self.config["use_gae"])
     # "new_obs" is not used afterwards; removing it saves bandwidth.
     del batch.data["new_obs"]
     return batch
Example #12
0
def add_advantages(policy,
                   sample_batch,
                   other_agent_batches=None,
                   episode=None):
    """Compute advantages for one trajectory from the policy's config.

    Uses ``gamma``, ``lambda``, ``use_gae`` and ``use_critic`` from
    ``policy.config``. Done trajectories bootstrap from 0.0. Truncated ones
    bootstrap from ``policy._value`` on the last next-observation, action,
    reward and the (unwrapped) last recurrent state outputs.
    """
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0
    else:
        # Truncated: estimate the final reward with the value function from
        # the terminal observation and internal recurrent state, if any.
        rnn_state = [sample_batch['state_out_{}'.format(k)][-1]
                     for k in range(policy.num_state_tensors())]
        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
                               sample_batch[SampleBatch.ACTIONS][-1],
                               sample_batch[SampleBatch.REWARDS][-1],
                               *rnn_state)

    cfg = policy.config
    return compute_advantages(sample_batch, last_r, cfg['gamma'],
                              cfg['lambda'], cfg['use_gae'],
                              cfg['use_critic'])
Example #13
0
def postprocesses_trajectories(policy,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
    """
    Postprocess a single trajectory for a manager/worker policy.

    Inputs are numpy arrays with shape [Time, Feature Dims...] or [Time]
    if there is only one feature. Note that inputs are not batched.

    Steps:
      1. Intrinsic reward: for each step i, sum the cosine similarity between
         (manager_latent_state[i] - manager_latent_state[i - j]) and
         manager_goal[i - j] over j = 1..fun_horizon (only where i - j >= 0),
         then divide by fun_horizon. Stored as 'fun_intrinsic_reward'.
      2. Manager: compute advantages / value targets from the current
         REWARDS with VF_PREDS set to 'manager_values'.
      3. Worker: add 0.9 * intrinsic reward to REWARDS (in place), then
         compute advantages / value targets with VF_PREDS set to
         'worker_values'.

    Args:
        policy: Supplies ``config`` ('fun_horizon', 'gamma', 'lambda',
            'use_gae', 'use_critic'), ``num_state_tensors()`` and ``_value``.
        sample_batch: The trajectory; modified in place.
        other_agent_batches: Not used here.
        episode: Not used here.

    Returns:
        The batch with 'fun_intrinsic_reward', 'manager_advantages',
        'manager_value_targets', 'worker_advantages' and
        'worker_value_targets' added.
    """
    horizon = policy.config['fun_horizon']
    seq_len = sample_batch[SampleBatch.REWARDS].shape[0]

    manager_latent_state = torch.Tensor(sample_batch['manager_latent_state'])
    manager_goal = torch.Tensor(sample_batch['manager_goal'])

    # NOTE(review): zeros_like inherits the dtype of REWARDS; if rewards were
    # stored as integers the fractional intrinsic reward would be truncated.
    # Confirm rewards are floating point.
    fun_intrinsic_reward = np.zeros_like(sample_batch[SampleBatch.REWARDS])
    for i in range(seq_len):
        reward = 0.0
        # Look back up to `horizon` steps; early steps have fewer terms but
        # are still divided by the full horizon below.
        for j in range(1, horizon + 1):
            if i - j >= 0:
                manager_latent_state_current = manager_latent_state[i]
                manager_latent_state_prev = manager_latent_state[i - j]
                manager_latent_state_diff = manager_latent_state_current - manager_latent_state_prev
                manager_goal_prev = manager_goal[i - j]
                reward = reward + F.cosine_similarity(
                    manager_latent_state_diff, manager_goal_prev, dim=0)
        fun_intrinsic_reward[i] = reward / horizon

    sample_batch['fun_intrinsic_reward'] = fun_intrinsic_reward

    completed = sample_batch[SampleBatch.DONES][-1]
    if completed:
        manager_last_r = 0.0
        worker_last_r = 0.0
    else:
        # Trajectory has been truncated, estimate final reward using the
        # value function from the terminal observation and
        # internal recurrent state if any
        next_state = []
        for i in range(policy.num_state_tensors()):
            next_state.append(sample_batch['state_out_{}'.format(i)][-1])
        # _value returns a (manager, worker) pair of value outputs.
        manager_last_r, worker_last_r = policy._value(
            sample_batch[SampleBatch.NEXT_OBS][-1],
            sample_batch[SampleBatch.ACTIONS][-1],
            sample_batch[SampleBatch.REWARDS][-1], *next_state)
        # Take the first element of each output (presumably a size-1 batch
        # dimension; confirm against _value).
        manager_last_r = manager_last_r[0]
        worker_last_r = worker_last_r[0]

    # Compute advantages and value targets for the manager
    sample_batch[SampleBatch.VF_PREDS] = sample_batch['manager_values']
    sample_batch = compute_advantages(sample_batch, manager_last_r,
                                      policy.config['gamma'],
                                      policy.config['lambda'],
                                      policy.config['use_gae'],
                                      policy.config['use_critic'])
    # NOTE(review): these keys alias the arrays stored under ADVANTAGES /
    # VALUE_TARGETS. They stay correct only if the second compute_advantages
    # call below stores new arrays rather than writing into these in place.
    sample_batch['manager_advantages'] = sample_batch[
        Postprocessing.ADVANTAGES]
    sample_batch['manager_value_targets'] = sample_batch[
        Postprocessing.VALUE_TARGETS]

    # The worker is trained on extrinsic reward plus weighted intrinsic
    # reward. This modifies REWARDS in place, after the manager targets
    # were computed.
    sample_batch[SampleBatch.REWARDS] += 0.9 * fun_intrinsic_reward
    # sample_batch[SampleBatch.REWARDS] = np.clip(
    #     sample_batch[SampleBatch.REWARDS], -1, 1)

    # Compute advantages and value targets for the worker
    sample_batch[SampleBatch.VF_PREDS] = sample_batch['worker_values']
    sample_batch = compute_advantages(sample_batch, worker_last_r,
                                      policy.config['gamma'],
                                      policy.config['lambda'],
                                      policy.config['use_gae'],
                                      policy.config['use_critic'])
    sample_batch['worker_advantages'] = sample_batch[Postprocessing.ADVANTAGES]
    sample_batch['worker_value_targets'] = sample_batch[
        Postprocessing.VALUE_TARGETS]

    # WARNING: These values are only used temporarily. Do not use:
    # sample_batch[SampleBatch.VF_PREDS]
    # sample_batch[Postprocessing.ADVANTAGES]
    # sample_batch[Postprocessing.VALUE_TARGETS]

    return sample_batch
Example #14
0
def centralized_critic_postprocessing(policy, sample_batch,
                                      other_agent_batches=None, episode=None):
    """Augment observations with neighbour actions, then compute advantages.

    Once the policy's loss is initialized, every observation in the batch is
    extended with the ``transform_action``-ed actions of the neighbouring
    aircraft listed in that step's info dict under ``"sequence"``. The
    actions are inserted at every third position starting at index 8, and
    the row is padded with -1 up to a fixed width of
    ``6 + 3 * settings.n_neighbours``. The resulting 2-D array is stored
    under NEW_OBS_ACTION, and ``policy.compute_central_vf`` on it overwrites
    VF_PREDS. Before initialization, zero placeholders are stored instead.

    Advantages bootstrap from 0.0 when the last step is done, otherwise from
    the last VF prediction.

    Returns:
        The batch returned by ``compute_advantages``.
    """
    # Width of an observation after neighbour actions are inserted/padded
    # (equals max(idx_insert + 1) below whenever n_neighbours >= 1).
    obs_width = 6 + settings.n_neighbours * 3

    if policy.loss_initialized():
        assert other_agent_batches is not None

        # Insert positions for neighbour actions: every third slot from 8.
        idx_insert = np.arange(8, 8 + settings.n_neighbours * 3, 3)

        # Collect rows and stack once at the end. This avoids quadratic
        # re-concatenation and guarantees a 2-D result even when the first
        # sample has no "sequence" entry.
        rows = []
        for idx_sample, info in enumerate(sample_batch[SampleBatch.INFOS]):
            obs = sample_batch[SampleBatch.CUR_OBS][idx_sample]
            if "sequence" in info:
                # Entry 0 of "sequence" is this agent's own ID; the remaining
                # IDs are the neighbouring aircraft, in the same order used by
                # the observation space.
                neighbours_ac = info['sequence'][1:]
                # Trailing dummy 0 so np.insert can also target the position
                # one past the end of the original observation.
                temp = np.append(obs, np.float32(0))
                for slot, agentid in enumerate(neighbours_ac):
                    opponent_action = other_agent_batches[agentid][1][
                        SampleBatch.ACTIONS][idx_sample]
                    temp = np.insert(temp, idx_insert[slot],
                                     transform_action(opponent_action))
                # Strip the dummy BEFORE padding so the padding is all -1.
                row = temp[:-1]
            else:
                # No "sequence": only this aircraft is left, so there are no
                # neighbour actions to insert.
                row = obs

            # Fewer active neighbours than settings.n_neighbours: pad with -1
            # so every row has the same width.
            fill = obs_width - len(row)
            if fill > 0:
                row = np.append(row, np.full(fill, -1, dtype=np.float32))
            rows.append(row)

        if rows:
            new_obs_array = np.stack(rows)
        else:
            new_obs_array = np.zeros((0, obs_width), dtype=np.float32)

        # Observations including neighbour actions.
        sample_batch[NEW_OBS_ACTION] = new_obs_array

        # Predicted value from the central value function.
        sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
            sample_batch[NEW_OBS_ACTION])
    else:
        # Loss not initialized yet: store dummy placeholders.
        sample_batch[NEW_OBS_ACTION] = np.zeros((1, obs_width),
                                                dtype=np.float32)
        sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
            sample_batch[SampleBatch.ACTIONS], dtype=np.float32)

    completed = sample_batch["dones"][-1]
    last_r = 0.0 if completed else sample_batch[SampleBatch.VF_PREDS][-1]

    return compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"])
Example #15
0
def postprocess_drq_ppo_gae(policy,
                            sample_batch,
                            other_agent_batches=None,
                            episode=None):
    """Compute augmentation-averaged value targets for PPO (DrQ style).

    For each of ``policy.config["aug_num"]`` augmentations, VF predictions
    are recomputed on augmented current observations, and
    ``compute_advantages`` is run on a copy of the batch. If the trajectory
    is not done, the bootstrap value is ``policy._value`` on the augmented
    last next-observation; otherwise it is 0.0.

    The returned batch is a copy of ``sample_batch`` with the ADVANTAGES of
    the first augmentation and the VALUE_TARGETS averaged over all
    augmentations. CUR_OBS is not augmented here (augment it in the loss).

    Raises:
        ValueError: if ``aug_num`` is smaller than 1.
    """
    aug_num = policy.config["aug_num"]
    if aug_num < 1:
        raise ValueError("aug_num must be >= 1, got {}".format(aug_num))

    completed = sample_batch["dones"][-1]

    # Loop-invariant inputs, converted once.
    cur_obs = torch.as_tensor(sample_batch[SampleBatch.CUR_OBS]).float()
    if not completed:
        next_state = [[sample_batch["state_out_{}".format(i)][-1]]
                      for i in range(policy.num_state_tensors())]
        nxt_obs = torch.as_tensor(
            sample_batch[SampleBatch.NEXT_OBS][-1:]).float()

    batch_final = None
    for _ in range(aug_num):
        batch_copy = sample_batch.copy()

        if not completed:
            # Augment the last next obs. The permutes move the channel axis
            # to position 1 for `trans` and back to last afterwards; this
            # presumably means obs are channels-last (confirm against the
            # model). [0] drops the leading size-1 dimension.
            aug_next_obs = policy.model.trans(
                nxt_obs.permute(0, 3, 1, 2))[0].permute(1, 2, 0)

        # Augment all current obs and recompute VF preds on them. This must
        # also happen for completed trajectories; otherwise every
        # augmentation yields identical, un-augmented value targets.
        aug_obs = policy.model.trans(cur_obs.permute(0, 3, 1, 2)).permute(
            0, 2, 3, 1)
        with torch.no_grad():
            _, _ = policy.model({
                "obs": aug_obs,
                "is_training": False
            }, [], None)
            batch_copy[SampleBatch.VF_PREDS] = policy.model.value_function(
            ).numpy()

        if completed:
            last_r = 0.0
        else:
            # Bootstrap value from the augmented last next obs.
            last_r = policy._value(aug_next_obs.numpy(),
                                   sample_batch[SampleBatch.ACTIONS][-1],
                                   sample_batch[SampleBatch.REWARDS][-1],
                                   *next_state)

        aug_batch = compute_advantages(batch_copy,
                                       last_r,
                                       policy.config["gamma"],
                                       policy.config["lambda"],
                                       use_gae=policy.config["use_gae"])

        # Keep the first augmentation's advantages; accumulate V targets.
        if batch_final is None:
            batch_final = sample_batch.copy()
            batch_final[Postprocessing.ADVANTAGES] = aug_batch[
                Postprocessing.ADVANTAGES]
            batch_final[Postprocessing.VALUE_TARGETS] = aug_batch[
                Postprocessing.VALUE_TARGETS]
        else:
            batch_final[Postprocessing.VALUE_TARGETS] += aug_batch[
                Postprocessing.VALUE_TARGETS]

    # Average the accumulated V targets over all augmentations.
    batch_final[Postprocessing.VALUE_TARGETS] /= aug_num
    return batch_final
Example #16
0
    def test_marwil_loss_function(self):
        """Check MARWIL's loss against a hand-computed reference.

        The offline CartPole data used here can be regenerated with:
        $ ./train.py --run=PPO --env=CartPole-v0 \
          --stop='{"timesteps_total": 50000}' \
          --config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}'
        """
        rllib_dir = Path(__file__).parent.parent.parent.parent
        print("rllib dir={}".format(rllib_dir))
        data_file = os.path.join(rllib_dir, "tests/data/cartpole/small.json")
        file_exists = os.path.isfile(data_file)
        print("data_file={} exists={}".format(data_file, file_exists))

        config = marwil.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0  # Run locally.
        config["input"] = [data_file]  # Learn from offline data.

        for fw in framework_iterator(config, frameworks=["torch", "tf2"]):
            is_torch = fw == "torch"
            batch = JsonReader(inputs=[data_file]).next()

            trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()
            model = policy.model

            # Reference values computed by hand: returns with lambda=1.0,
            # no GAE and no critic.
            returns = compute_advantages(batch, 0.0, config["gamma"], 1.0,
                                         False, False)["advantages"]
            if is_torch:
                returns = torch.tensor(returns)
            batch = policy._lazy_tensor_dict(batch)
            model_out, _ = model.from_batch(batch)
            adv = returns - model.value_function()
            if is_torch:
                adv = adv.detach().cpu().numpy()
            adv_squared = np.mean(np.square(adv))
            # Expected normalizer: sqrt(100 + 1e-8 * (mean(adv^2) - 100)).
            c = np.sqrt(100.0 + 1e-8 * (adv_squared - 100.0))
            exp_advs = np.exp(config["beta"] * (adv / c))
            logp = policy.dist_class(model_out, model).logp(batch["actions"])
            if is_torch:
                logp = logp.detach().cpu().numpy()

            expected_vf_loss = 0.5 * adv_squared
            expected_pol_loss = -1.0 * np.mean(exp_advs * logp)
            expected_loss = (expected_pol_loss +
                             config["vf_coeff"] * expected_vf_loss)

            # Now let the algorithm compute its loss on the same data.
            batch.set_get_interceptor(None)
            postprocessed_batch = policy.postprocess_trajectory(batch)
            if is_torch:
                loss_func = marwil.marwil_torch_policy.marwil_loss
            else:
                loss_func = marwil.marwil_tf_policy.marwil_loss
            loss_out = loss_func(policy, model, policy.dist_class,
                                 policy._lazy_tensor_dict(postprocessed_batch))

            # Torch policies expose the loss terms directly; tf2 ones under
            # `policy.loss`.
            loss_holder = policy if is_torch else policy.loss
            check(loss_holder.v_loss, expected_vf_loss, decimals=4)
            check(loss_holder.p_loss, expected_pol_loss, decimals=4)
            check(loss_out, expected_loss, decimals=3)
Example #17
0
 def postprocess_trajectory(self,
                            batch,
                            other_agent_batches=None,
                            episode=None):
     """Compute advantages with fixed constants; an episode is required.

     Uses a bootstrap value of 100.0, gamma 0.9 and GAE disabled.
     """
     assert episode is not None
     bootstrap_value = 100.0
     return compute_advantages(batch, bootstrap_value, gamma=0.9,
                               use_gae=False)
def centralized_critic_postprocessing(policy,
                                      sample_batch,
                                      other_agent_batches=None,
                                      episode=None):
    """Build the central-critic observation and compute advantages.

    Once the policy's loss is initialized:
    - Other agents whose ``"t"`` span overlaps this agent's (per
      ``time_overlap``) are selected.
    - Their observations are passed through ``overlap_and_pad_agent`` and
      stacked horizontally in front of this agent's own observations.
    - The result is zero-padded up to ``policy.model.max_num_agents`` agents.
    - It is stored under OPPONENT_OBS, and the central VF overwrites VF_PREDS.
    Before initialization, zero placeholders are stored instead.

    Advantages bootstrap from 0.0 when the last step is done, otherwise from
    the last VF prediction. Previously 0.0 was always used, which dropped the
    bootstrap term for truncated trajectories.
    """
    if policy.loss_initialized():
        assert other_agent_batches is not None

        time_span = (sample_batch['t'][0], sample_batch['t'][-1])
        other_agent_times = {
            agent_id: (other_agent_batches[agent_id][1]["t"][0],
                       other_agent_batches[agent_id][1]["t"][-1])
            for agent_id in other_agent_batches.keys()
        }
        # Keep only agents whose time window overlaps the current agent's.
        rel_agents = {
            agent_id: other_agent_time
            for agent_id, other_agent_time in other_agent_times.items()
            if time_overlap(time_span, other_agent_time)
        }
        if rel_agents:
            padded_agent_obs = [
                overlap_and_pad_agent(
                    time_span, rel_agent_time,
                    other_agent_batches[agent_id][1]["obs"].copy())
                for agent_id, rel_agent_time in rel_agents.items()
            ]
            # Other agents' observations first, then this agent's own.
            central_obs_batch = np.hstack(
                padded_agent_obs + [sample_batch["obs"]])
        else:
            central_obs_batch = sample_batch["obs"]

        max_vf_agents = policy.model.max_num_agents
        num_agents = len(rel_agents) + 1
        if num_agents < max_vf_agents:
            # Zero-fill the slots of absent agents to keep a fixed width.
            diff = max_vf_agents - num_agents
            zero_pad = np.zeros((central_obs_batch.shape[0],
                                 policy.model.obs_space_shape * diff))
            central_obs_batch = np.hstack((central_obs_batch, zero_pad))
        elif num_agents > max_vf_agents:
            print("Too many agents!")

        # Record the central observation in the trajectory.
        sample_batch[OPPONENT_OBS] = central_obs_batch

        # Overwrite the default VF prediction with the central VF.
        sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
            sample_batch[SampleBatch.CUR_OBS], sample_batch[OPPONENT_OBS])
    else:
        # Policy hasn't initialized yet: use zeros.
        # TODO(evinitsky) put in the right shape
        obs_shape = sample_batch[SampleBatch.CUR_OBS].shape[1]
        obs_shape = (1, obs_shape * (policy.model.max_num_agents))
        sample_batch[OPPONENT_OBS] = np.zeros(obs_shape)
        # TODO(evinitsky) put in the right shape. Will break if actions aren't 1
        sample_batch[SampleBatch.VF_PREDS] = np.zeros(1, dtype=np.float32)

    # Bootstrap truncated trajectories from the last VF prediction, matching
    # the other centralized-critic postprocessors.
    if sample_batch["dones"][-1]:
        last_r = 0.0
    else:
        last_r = sample_batch[SampleBatch.VF_PREDS][-1]

    return compute_advantages(sample_batch,
                              last_r,
                              policy.config["gamma"],
                              policy.config["lambda"],
                              use_gae=policy.config["use_gae"])
Example #19
0
 def postprocess_trajectory(self,
                            batch,
                            other_agent_batches=None,
                            episode=None):
     """Compute advantages from fixed constants (requires an episode).

     Bootstrap value 100.0, gamma 0.9, GAE disabled.
     """
     assert episode is not None
     fixed_last_r, fixed_gamma = 100.0, 0.9
     return compute_advantages(batch, fixed_last_r, fixed_gamma,
                               use_gae=False)
Example #20
0
def postprocess_trajectory(policy: TFPolicy,
                           sample_batch: SampleBatch,
                           other_agent_batches=None,
                           episode=None):
    """Attach averaged peer gradients/variables and advantages to a batch.

    Once the policy's loss is initialized, the arrays exposed by every other
    agent's policy (``gamma_grads_ndarray_dict`` / ``vars_ndarray_dict``) are
    averaged per variable name (``name_ref(key)``) and tiled along a new
    leading batch axis into ``gamma_<name>`` / ``var_<name>`` columns. Before
    the loss is initialized, placeholder columns are built from
    ``policy.init_grads_and_vars()``: zero gradients and tiled variables.
    Advantages are then computed with a bootstrap value of 0.0.

    Args:
        policy: The policy that produced ``sample_batch``.
        sample_batch: The trajectory to postprocess; columns are added to it
            in place.
        other_agent_batches: Mapping of agent id -> (policy, batch) for the
            other agents. Required once the loss is initialized.
        episode: Not used by this function.

    Returns:
        The batch returned by ``compute_advantages``.

    Raises:
        ValueError: If the loss is initialized but ``other_agent_batches`` is
            None.
    """
    last_r = 0.0
    batch_length = len(sample_batch[SampleBatch.CUR_OBS])
    var_names = policy.get_pure_var_names()
    other_gradients = {k: [] for k in var_names}
    other_vars = {k: [] for k in var_names}

    if policy.loss_initialized():
        if other_agent_batches is None:
            raise ValueError(
                "other_agent_batches is required once the loss is "
                "initialized")
        num_others = len(other_agent_batches)

        for other_policy, batch in other_agent_batches.values():
            assert isinstance(other_policy, TFPolicy)
            assert isinstance(batch, SampleBatch)

            # Named `var_arrays` (not `vars`) to avoid shadowing the builtin.
            grads = other_policy.gamma_grads_ndarray_dict
            var_arrays = other_policy.vars_ndarray_dict
            for k, grad in grads.items():
                var = var_arrays[k]
                name = name_ref(k)

                assert var.shape == grad.shape, (k, var.shape, grad.shape)
                # Divide up front so the sum below yields the mean over agents.
                other_gradients[name].append(grad / num_others)
                other_vars[name].append(var / num_others)

        for v in other_vars.values():
            assert len(v) == num_others, (len(v), num_others)

        # Pack the averaged gradients / variables as per-timestep columns.
        for name, grad_list in other_gradients.items():
            assert len(other_vars[name]) > 0, name
            assert len(grad_list) > 0, name
            var_nested = np.sum(other_vars[name], axis=0)
            grad_nested = np.sum(grad_list, axis=0)

            assert var_nested.shape == other_vars[name][0].shape, (
                var_nested.shape,
                other_vars[name][0].shape,
            )

            assert grad_nested.shape == var_nested.shape, (
                grad_nested.shape,
                var_nested.shape,
            )
            # Repeat along a new leading batch axis; keep all other dims.
            reps = (batch_length, ) + (1, ) * len(grad_nested.shape)

            sample_batch[f"gamma_{name}"] = np.tile(grad_nested, reps)
            sample_batch[f"var_{name}"] = np.tile(var_nested, reps)
            assert (sample_batch[f"gamma_{name}"].shape ==
                    sample_batch[f"var_{name}"].shape), (
                        sample_batch[f"gamma_{name}"].shape,
                        sample_batch[f"var_{name}"].shape)
    else:
        # Loss not initialized yet: build placeholder columns with the right
        # shapes from the policy's own (grad, var) pairs.
        grads_and_vars = policy.init_grads_and_vars()
        for k, (grad, var) in grads_and_vars.items():
            name = name_ref(k)
            assert other_gradients.get(name, None) is not None, name
            other_gradients[name].append(grad)
            other_vars[name].append(var)
            sample_batch[f"gamma_{name}"] = np.zeros((batch_length, ) +
                                                     grad.shape)
            sample_batch[f"var_{name}"] = np.tile(
                var, (batch_length, ) + (1, ) * len(var.shape))
            assert (sample_batch[f"gamma_{name}"].shape ==
                    sample_batch[f"var_{name}"].shape), (
                        sample_batch[f"gamma_{name}"].shape,
                        sample_batch[f"var_{name}"].shape)

    train_batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config.get("gamma", 0.9),
        policy.config.get("lambda", 1.0),
        policy.config.get("use_gae", False),
        policy.config.get("use_critic", False),
    )

    return train_batch
Example #21
0
 def postprocess_trajectory(self,
                            sample_batch,
                            other_agent_batches=None,
                            episode=None):
     """Add the "advantages" column to the sample batch without GAE.

     The trajectory is bootstrapped with 0.0; only ``self.config["gamma"]``
     is read from the config.
     """
     gamma = self.config["gamma"]
     return compute_advantages(sample_batch, 0.0, gamma, use_gae=False)
Example #22
0
def postprocess_ppo_gae(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
    """Compute advantages (GAE or plain) for a single-agent trajectory.

    ``sample_batch`` holds data from one episode of one agent. With
    ``batch_mode=truncate_episodes`` (the default) the episode may have been
    cut off when ``rollout_fragment_length`` was reached; with
    ``batch_mode=complete_episodes`` it always holds one whole episode.
    Columns may be added to or altered in ``sample_batch``.

    Args:
        policy (Policy): The policy that generated ``sample_batch``.
        sample_batch (SampleBatch): The trajectory to postprocess.
        other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Other
            agents' trajectories from the same episode (same policy).
        episode (Optional[MultiAgentEpisode]): The multi-agent episode the
            agents acted in.

    Returns:
        SampleBatch: The batch returned by ``compute_advantages``.
    """
    if sample_batch[SampleBatch.DONES][-1]:
        # The episode really ended: there is nothing to bootstrap from.
        last_r = 0.0
    elif policy.config["_use_trajectory_view_api"]:
        # Truncated: bootstrap from the value estimate of the last timestep,
        # using a single-timestep input dict built per the model's
        # view requirements.
        input_dict = policy.model.get_input_dict(sample_batch, index="last")
        last_r = policy._value(**input_dict)
    else:
        # TODO: (sven) Remove once trajectory view API is all-algo default.
        last_state = [
            sample_batch["state_out_{}".format(idx)][-1]
            for idx in range(policy.num_state_tensors())
        ]
        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
                               sample_batch[SampleBatch.ACTIONS][-1],
                               sample_batch[SampleBatch.REWARDS][-1],
                               *last_state)

    # Adds the value targets and advantages (GAE or not, per config).
    gamma = policy.config["gamma"]
    lambda_ = policy.config["lambda"]
    return compute_advantages(sample_batch,
                              last_r,
                              gamma,
                              lambda_,
                              use_gae=policy.config["use_gae"],
                              use_critic=policy.config.get(
                                  "use_critic", True))
Example #23
0
 def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
     """Compute advantages, treating the trajectory as finished (0.0).

     Discount, GAE lambda and the GAE switch are read from ``self.config``.
     """
     cfg = self.config
     return compute_advantages(sample_batch, 0.0, cfg["gamma"],
                               cfg["lambda"], use_gae=cfg["use_gae"])
def compute_gae_for_sample_batch(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
    """Adds per-sub-agent GAE (generalized advantage estimations) to a batch.

    The trajectory contains only data from one episode and from one agent.
    - If  `config.batch_mode=truncate_episodes` (default), sample_batch may
    contain a truncated (at-the-end) episode, in case the
    `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
    exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be altered.

    The policy's Tuple action space bundles several sub-agents. For
    sub-agent ``i``, rewards come from the ``"rewards"`` dict in each info
    dict (key ``i``), actions are the ``i``-th slice of the flat ``ACTIONS``
    column, and value predictions are column ``i`` of ``VF_PREDS``. The
    per-sub-agent REWARDS, VF_PREDS, ADVANTAGES and VALUE_TARGETS are then
    stacked along the last axis back into ``sample_batch``.

    Args:
        policy (Policy): The Policy used to generate the trajectory
            (`sample_batch`)
        sample_batch (SampleBatch): The SampleBatch to postprocess.
        other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
            dict of AgentIDs mapping to other agents' trajectory data (from the
            same episode). NOTE: The other agents use the same policy.
        episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
            object in which the agents operated.

    Returns:
        SampleBatch: ``sample_batch`` itself, modified in place.

    Raises:
        InvalidActionSpace: If the action space is not a Tuple, or one of its
            components is neither a Box nor a Discrete space.
    """

    # On its first call the trajectory view API populates INFOS with a
    # np.zeros((n,)) float32 array instead of info dicts; in that case the
    # array is used as-is. For regular calls, the rewards are extracted from
    # each info dict's "rewards" entry into samplebatch_infos_rewards, which
    # then holds the rewards of all sub-agents keyed by str(index).
    samplebatch_infos_rewards = {'0': sample_batch[SampleBatch.INFOS]}
    if sample_batch[SampleBatch.INFOS].dtype != "float32":
        samplebatch_infos = SampleBatch.concat_samples([
            SampleBatch({k: [v]
                         for k, v in s.items()})
            for s in sample_batch[SampleBatch.INFOS]
        ])
        samplebatch_infos_rewards = SampleBatch.concat_samples([
            SampleBatch({str(k): [v]
                         for k, v in s.items()})
            for s in samplebatch_infos["rewards"]
        ])

    if not isinstance(policy.action_space, gym.spaces.tuple.Tuple):
        raise InvalidActionSpace("Expect tuple action space")

    # The bootstrap value computation does not depend on the sub-agent, so
    # evaluate the value function at most once (it was previously re-run for
    # every sub-agent). Its output is indexed per sub-agent below.
    if sample_batch[SampleBatch.DONES][-1]:
        # Trajectory is actually complete -> last r=0.0 for every sub-agent.
        all_values = None
    else:
        # Trajectory has been truncated -> last r=VF estimate of last obs,
        # using a single-timestep input dict built per the model's
        # requirements.
        input_dict = policy.model.get_input_dict(sample_batch, index="last")
        all_values = policy._value(**input_dict,
                                   seq_lens=input_dict.seq_lens)

    # samplebatches for each sub-agent
    batches = []
    for key, action_space in zip(samplebatch_infos_rewards.keys(),
                                 policy.action_space):
        i = int(key)
        sample_batch_agent = sample_batch.copy()
        sample_batch_agent[SampleBatch.REWARDS] = (
            samplebatch_infos_rewards[key])
        if isinstance(action_space, gym.spaces.box.Box):
            assert len(action_space.shape) == 1
            a_w = action_space.shape[0]
        elif isinstance(action_space, gym.spaces.discrete.Discrete):
            a_w = 1
        else:
            raise InvalidActionSpace(
                "Expect gym.spaces.box or gym.spaces.discrete action space")

        # NOTE(review): the slice offset `a_w * i` assumes every sub-agent
        # has the same action width — confirm for mixed action spaces.
        sample_batch_agent[SampleBatch.ACTIONS] = sample_batch[
            SampleBatch.ACTIONS][:, a_w * i:a_w * (i + 1)]
        sample_batch_agent[SampleBatch.VF_PREDS] = sample_batch[
            SampleBatch.VF_PREDS][:, i]

        last_r = 0.0 if all_values is None else all_values[i].item()

        # Adds the policy logits, VF preds, and advantages to the batch,
        # using GAE ("generalized advantage estimation") or not.
        batches.append(
            compute_advantages(sample_batch_agent,
                               last_r,
                               policy.config["gamma"],
                               policy.config["lambda"],
                               use_gae=policy.config["use_gae"],
                               use_critic=policy.config.get(
                                   "use_critic", True)))

    # Overwrite these columns of the original batch with the per-sub-agent
    # results stacked along the last axis.
    for k in [
            SampleBatch.REWARDS,
            SampleBatch.VF_PREDS,
            Postprocessing.ADVANTAGES,
            Postprocessing.VALUE_TARGETS,
    ]:
        sample_batch[k] = np.stack([b[k] for b in batches], axis=-1)

    return sample_batch
Example #25
0
 def postprocess_trajectory(self, batch, other_agent_batches=None):
     """Compute non-GAE advantages with a fixed bootstrap value and discount."""
     bootstrap_value, discount = 100.0, 0.9
     return compute_advantages(batch, bootstrap_value, discount,
                               use_gae=False)
Example #26
0
def post_process_advantages(policy, sample_batch, other_agent_batches=None,
                            episode=None):
    """Add the "advantages" column to ``sample_batch`` without GAE.

    The trajectory is bootstrapped with 0.0 and only
    ``policy.config["gamma"]`` is read; ``other_agent_batches`` and
    ``episode`` are ignored.
    """
    discount = policy.config["gamma"]
    return compute_advantages(sample_batch, 0.0, discount, use_gae=False)
Example #27
0
def centralized_critic_postprocessing(policy,
                                      sample_batch,
                                      other_agent_batches=None,
                                      episode=None):
    """Attach opponents' observations/actions and central-VF predictions.

    Once the loss is initialized:

    - opponents whose lifespan overlaps this agent's are collected;
    - each one is cropped/padded to this agent's lifespan;
    - the list is padded with all-zero "missing" opponents up to
      ``policy.model.max_num_opponents`` and randomly permuted;
    - all observations, then all one-hot actions, are concatenated into the
      ``OTHER_AGENT`` column;
    - ``VF_PREDS`` is overwritten with the central value function's output.

    A lifespan is the pair (first, last) of the batch's ``"t"`` column.

    Before the loss is initialized, zero placeholders are written instead.
    Advantages are then computed with a bootstrap value of 0.0.

    Args:
        policy: Policy whose model exposes ``max_num_opponents`` and which
            provides ``compute_central_value_function``.
        sample_batch: This agent's trajectory; columns are set in place.
        other_agent_batches: Mapping of agent id -> (policy, batch) for the
            other agents. Required once the loss is initialized.
        episode: Not used by this function.

    Returns:
        The batch returned by ``compute_advantages``.

    Raises:
        ValueError: If there are more other agents than ``max_num_opponents``.
    """
    # one hot encoding parser
    one_hot_enc = functools.partial(one_hot_encoding,
                                    n_classes=policy.action_space.n)
    max_num_opponents = policy.model.max_num_opponents

    if policy.loss_initialized():
        assert other_agent_batches is not None

        if len(other_agent_batches) > max_num_opponents:
            raise ValueError(
                "The number of opponents is too large, got {} (max at {})".
                format(len(other_agent_batches), max_num_opponents))

        # lifespan of the current agent
        time_span = (sample_batch["t"][0], sample_batch["t"][-1])

        # Keep only the opponents whose lifespan overlaps the current agent's.
        # Each opponent's span is computed once and reused.
        opponents = []
        for _, opp_batch in other_agent_batches.values():
            opp_span = (opp_batch["t"][0], opp_batch["t"][-1])
            if time_overlap(time_span, opp_span):
                opponents.append(
                    Opponent(
                        opp_span,
                        opp_batch[SampleBatch.CUR_OBS],
                        one_hot_enc(opp_batch[SampleBatch.ACTIONS]),
                    ))

        # apply the adequate cropping or padding compared to time_span
        for opp in opponents:
            opp.crop_or_pad(time_span)

        # Pad with a zero placeholder for the missing opponents. The same
        # object is repeated; below it is only read, never mutated.
        missing_opponent = Opponent(
            None,
            np.zeros_like(sample_batch[SampleBatch.CUR_OBS]),
            one_hot_enc(np.zeros_like(sample_batch[SampleBatch.ACTIONS])),
        )
        opponents = opponents + ([missing_opponent] *
                                 (max_num_opponents - len(opponents)))

        # add random permutation of the opponents
        perm = np.random.permutation(np.arange(max_num_opponents))
        opponents = [opponents[p] for p in perm]

        # add the opponents' information into sample_batch
        sample_batch[OTHER_AGENT] = np.concatenate(
            [opp.observation
             for opp in opponents] + [opp.action for opp in opponents],
            axis=-1,
        )
        # overwrite default VF prediction with the central VF
        sample_batch[
            SampleBatch.VF_PREDS] = policy.compute_central_value_function(
                sample_batch[SampleBatch.CUR_OBS], sample_batch[OTHER_AGENT])

    else:

        # opponents' observation placeholder
        missing_obs = np.zeros_like(sample_batch[SampleBatch.CUR_OBS])
        missing_act = one_hot_enc(
            np.zeros_like(sample_batch[SampleBatch.ACTIONS]))
        sample_batch[OTHER_AGENT] = np.concatenate(
            [missing_obs for _ in range(max_num_opponents)] +
            [missing_act for _ in range(max_num_opponents)],
            axis=-1,
        )

        # value prediction placeholder
        sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
            sample_batch[SampleBatch.ACTIONS], dtype=np.float32)

    train_batch = compute_advantages(
        sample_batch,
        0.0,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"],
    )
    return train_batch
Example #28
0
def postprocess_advantages(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Add MARWIL's "advantages" (discounted cumulative rewards) to a batch.

    ``sample_batch`` holds data from one episode of one agent. With
    ``batch_mode=truncate_episodes`` (the default) the episode may have been
    cut off when ``rollout_fragment_length`` was reached; with
    ``batch_mode=complete_episodes`` it always holds one whole episode.
    Columns may be added to or altered in ``sample_batch``.

    Args:
        policy (Policy): The policy that generated ``sample_batch``.
        sample_batch (SampleBatch): The trajectory to postprocess.
        other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Other
            agents' trajectories from the same episode (same policy).
        episode (Optional[MultiAgentEpisode]): The multi-agent episode the
            agents acted in.

    Returns:
        SampleBatch: The batch returned by ``compute_advantages``.
    """
    trajectory_done = sample_batch[SampleBatch.DONES][-1]
    if trajectory_done:
        # The episode really ended: nothing to bootstrap from.
        last_r = 0.0
    elif policy.config["_use_trajectory_view_api"]:
        # Truncated: bootstrap from the VF estimate of the last timestep.
        # Use index "last" when NEXT_OBS is present, otherwise -1.
        has_next_obs = SampleBatch.NEXT_OBS in sample_batch.data
        input_dict = policy.model.get_input_dict(
            sample_batch, index="last" if has_next_obs else -1)
        last_r = policy._value(**input_dict)
    else:
        # TODO: (sven) Remove once trajectory view API is all-algo default.
        last_state = [
            sample_batch["state_out_{}".format(idx)][-1]
            for idx in range(policy.num_state_tensors())
        ]
        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
                               sample_batch[SampleBatch.ACTIONS][-1],
                               sample_batch[SampleBatch.REWARDS][-1],
                               *last_state)

    # MARWIL only needs the discounted cumulative rewards, so neither GAE nor
    # the critic baseline (use_critic=True would subtract VF estimates from
    # the returns) is used.
    return compute_advantages(sample_batch,
                              last_r,
                              policy.config["gamma"],
                              use_gae=False,
                              use_critic=False)
Example #29
0
def postprocess_advantages(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Add non-GAE advantages, bootstrapping the trajectory with 0.0."""
    gamma = policy.config["gamma"]
    return compute_advantages(sample_batch, 0.0, gamma, use_gae=False)
Example #30
0
    def test_marwil_loss_function(self):
        """Check MARWIL's loss against a reference computed by hand in numpy.

        For each framework yielded by `framework_iterator`, one batch of
        offline CartPole data is read. The expected value loss, policy loss
        and total loss are computed directly from the model outputs. Then the
        policy's own postprocessing and loss function are run on the same
        batch, and all components are compared.

        To generate the historic data used in this test case, first run:
        $ ./train.py --run=PPO --env=CartPole-v0 \
          --stop='{"timesteps_total": 50000}' \
          --config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}'
        """
        # Locate the bundled offline data file relative to this test file.
        rllib_dir = Path(__file__).parent.parent.parent.parent
        print("rllib dir={}".format(rllib_dir))
        data_file = os.path.join(rllib_dir, "tests/data/cartpole/small.json")
        print("data_file={} exists={}".format(data_file,
                                              os.path.isfile(data_file)))

        config = (marwil.MARWILConfig().rollouts(
            num_rollout_workers=0).offline_data(input_=[data_file])
                  )  # Learn from offline data.

        for fw, sess in framework_iterator(config, session=True):
            # Read one batch of offline data to feed to the model directly.
            reader = JsonReader(inputs=[data_file])
            batch = reader.next()

            trainer = config.build(env="CartPole-v0")
            policy = trainer.get_policy()
            model = policy.model

            # Calculate our own expected values (to then compare against the
            # agent's loss output).
            # With bootstrap 0.0, lambda 1.0 and GAE/critic disabled, the
            # "advantages" column holds the discounted cumulative rewards.
            cummulative_rewards = compute_advantages(batch, 0.0, config.gamma,
                                                     1.0, False,
                                                     False)["advantages"]
            if fw == "torch":
                cummulative_rewards = torch.tensor(cummulative_rewards)
            if fw != "tf":
                batch = policy._lazy_tensor_dict(batch)
            model_out, _ = model(batch)
            vf_estimates = model.value_function()
            if fw == "tf":
                # Static-graph mode: evaluate the tensors in the session.
                model_out, vf_estimates = policy.get_session().run(
                    [model_out, vf_estimates])
            adv = cummulative_rewards - vf_estimates
            if fw == "torch":
                adv = adv.detach().cpu().numpy()
            adv_squared = np.mean(np.square(adv))
            # NOTE(review): presumably mirrors MARWIL's running average of
            # the squared advantage norm (start 100.0, update rate 1e-8) —
            # confirm against the loss implementation.
            c_2 = 100.0 + 1e-8 * (adv_squared - 100.0)
            c = np.sqrt(c_2)
            exp_advs = np.exp(config.beta * (adv / c))
            dist = policy.dist_class(model_out, model)
            logp = dist.logp(batch["actions"])
            if fw == "torch":
                logp = logp.detach().cpu().numpy()
            elif fw == "tf":
                logp = sess.run(logp)
            # Calculate all expected loss components.
            expected_vf_loss = 0.5 * adv_squared
            expected_pol_loss = -1.0 * np.mean(exp_advs * logp)
            expected_loss = expected_pol_loss + config.vf_coeff * expected_vf_loss

            # Calculate the algorithm's loss (to check against our own
            # calculation above).
            # NOTE(review): presumably removes the tensor-conversion
            # interceptor installed by `_lazy_tensor_dict` above so
            # postprocessing sees raw arrays — confirm.
            batch.set_get_interceptor(None)
            postprocessed_batch = policy.postprocess_trajectory(batch)
            loss_func = (MARWILTF2Policy.loss
                         if fw != "torch" else MARWILTorchPolicy.loss)
            if fw != "tf":
                policy._lazy_tensor_dict(postprocessed_batch)
                loss_out = loss_func(policy, model, policy.dist_class,
                                     postprocessed_batch)
            else:
                loss_out, v_loss, p_loss = policy.get_session().run(
                    # policy._loss is create by TFPolicy, and is basically the
                    # loss tensor of the static graph.
                    [
                        policy._loss,
                        policy._marwil_loss.v_loss,
                        policy._marwil_loss.p_loss,
                    ],
                    feed_dict=policy._get_loss_inputs_dict(postprocessed_batch,
                                                           shuffle=False),
                )

            # Check all components.
            # The loss components are read from a different place for
            # each framework.
            if fw == "torch":
                check(policy.v_loss, expected_vf_loss, decimals=4)
                check(policy.p_loss, expected_pol_loss, decimals=4)
            elif fw == "tf":
                check(v_loss, expected_vf_loss, decimals=4)
                check(p_loss, expected_pol_loss, decimals=4)
            else:
                check(policy._marwil_loss.v_loss, expected_vf_loss, decimals=4)
                check(policy._marwil_loss.p_loss,
                      expected_pol_loss,
                      decimals=4)
            check(loss_out, expected_loss, decimals=3)
Example #31
0
def postprocess_trajectory(policy: TFPolicy,
                           sample_batch: SampleBatch,
                           other_agent_batches=None,
                           episode=None):
    """Build mean-field critic inputs and compute advantages.

    The critic observation column
    (``CentralizedActorCriticModel.CRITIC_OBS``) is built per timestep by
    concatenating three parts along the last axis:

    - this agent's observation;
    - its own action (one-hot for Discrete spaces);
    - the mean of the other agents' actions.

    Once the loss is initialized, these parts are filled from the batches
    and ``VF_PREDS`` is overwritten with the central value function.
    Before that, zero placeholders of the same shapes are used.

    NOTE: when the loss is initialized, the last ``DONES`` entry of
    ``sample_batch`` is forced to 1 in place, and advantages are
    bootstrapped with 0.0.

    Args:
        policy: Policy whose model exposes ``act_preprocessor`` and
            ``obs_preprocessor`` and which provides ``compute_central_vf``.
        sample_batch: This agent's trajectory; columns are set in place.
        other_agent_batches: Mapping of agent id -> (policy, batch) for the
            other agents. ``None`` is treated as "no other agents".
        episode: Not used by this function.

    Returns:
        The batch returned by ``compute_advantages``.

    Raises:
        NotImplementedError: If the action space is neither Discrete nor Box.
    """
    last_r = 0.0
    batch_length = len(sample_batch[SampleBatch.CUR_OBS])
    action_preprocessor = policy.model.act_preprocessor
    obs_preprocessor = policy.model.obs_preprocessor

    mean_action = np.zeros((batch_length, ) + action_preprocessor.shape)
    own_action = np.zeros((batch_length, ) + action_preprocessor.shape)
    own_obs = np.zeros((batch_length, ) + obs_preprocessor.shape)

    if policy.loss_initialized():
        sample_batch[SampleBatch.DONES][-1] = 1
        # A missing mapping means no other agents; the mean action then stays
        # zero (see the max(1, ...) guard below).
        other_agent_batches = OrderedDict(other_agent_batches or {})
        for _, batch in other_agent_batches.values():
            # Other agents may have shorter trajectories than this one.
            copy_length = min(batch_length,
                              batch[SampleBatch.CUR_OBS].shape[0])

            # TODO(ming): check the action type
            if isinstance(policy.action_space, spaces.Discrete):
                buffer_action = np.eye(action_preprocessor.size)[batch[
                    SampleBatch.ACTIONS][:copy_length]]
            elif isinstance(policy.action_space, spaces.Box):
                buffer_action = batch[SampleBatch.ACTIONS][:copy_length]
            else:
                raise NotImplementedError(
                    f"Do not support such an action space yet:{type(policy.action_space)}"
                )

            mean_action[:copy_length] += buffer_action
        # fill my features to critic_obs_array
        if isinstance(policy.action_space, spaces.Box):
            buffer_action = sample_batch[SampleBatch.ACTIONS]
        elif isinstance(policy.action_space, spaces.Discrete):
            buffer_action = np.eye(action_preprocessor.size)[sample_batch[
                SampleBatch.ACTIONS][:batch_length]]
        else:
            raise NotImplementedError(
                f"Do not support such an action space yet: {type(policy.action_space)}"
            )
        own_action[:batch_length] = buffer_action
        own_obs[:] = sample_batch[SampleBatch.CUR_OBS]
        mean_action /= max(1, len(other_agent_batches))

        sample_batch[CentralizedActorCriticModel.CRITIC_OBS] = np.concatenate(
            [own_obs, own_action, mean_action], axis=-1)
        sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
            sample_batch[CentralizedActorCriticModel.CRITIC_OBS])
    else:
        sample_batch[CentralizedActorCriticModel.CRITIC_OBS] = np.concatenate(
            [own_obs, own_action, mean_action], axis=-1)
        # One zero value prediction per timestep. (Previously np.zeros_like
        # was applied to the tuple itself, yielding a shape-(1,) array.)
        sample_batch[SampleBatch.VF_PREDS] = np.zeros((batch_length, ),
                                                      dtype=np.float32)

    train_batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        policy.config["use_gae"],
    )
    return train_batch