from collections import OrderedDict

import numpy as np
from gym import spaces
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy

# CentralizedActorCriticModel is defined elsewhere in this project and
# provides the CRITIC_OBS batch key plus the critic/act/obs preprocessors.


def postprocess_trajectory(
    policy: TFPolicy, sample_batch: SampleBatch, other_agent_batches=None, episode=None
):
    last_r = 0.0
    batch_length = len(sample_batch[SampleBatch.CUR_OBS])
    critic_preprocessor = policy.model.critic_preprocessor
    action_preprocessor = policy.model.act_preprocessor
    obs_preprocessor = policy.model.obs_preprocessor
    critic_obs_array = np.zeros((batch_length,) + critic_preprocessor.shape)
    # Each agent occupies one flat slot of size |obs| + |action| in the critic input.
    offset_slot = action_preprocessor.size + obs_preprocessor.size

    if policy.loss_initialized():
        # ordered by agent keys
        other_agent_batches = OrderedDict(other_agent_batches)
        for i, (other_id, (other_policy, batch)) in enumerate(
            other_agent_batches.items()
        ):
            # Opponent i fills slot i + 1; slot 0 is reserved for this agent.
            offset = (i + 1) * offset_slot
            copy_length = min(batch_length, batch[SampleBatch.CUR_OBS].shape[0])
            # TODO(ming): check the action type
            buffer_action = get_action_buffer(
                policy.action_space, action_preprocessor, batch, copy_length
            )
            oppo_features = np.concatenate(
                [batch[SampleBatch.CUR_OBS][:copy_length], buffer_action], axis=-1
            )
            assert oppo_features.shape[-1] == offset_slot
            critic_obs_array[
                :copy_length, offset : offset + offset_slot
            ] = oppo_features

        # fill my features into slot 0 of critic_obs_array
        buffer_action = get_action_buffer(
            policy.action_space, action_preprocessor, sample_batch, batch_length
        )
        critic_obs_array[:batch_length, 0:offset_slot] = np.concatenate(
            [sample_batch[SampleBatch.CUR_OBS], buffer_action], axis=-1
        )
        sample_batch[CentralizedActorCriticModel.CRITIC_OBS] = critic_obs_array
        sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
            sample_batch[CentralizedActorCriticModel.CRITIC_OBS]
        )
    else:
        # Before the loss is initialized, RLlib runs this on a dummy batch:
        # emit placeholder critic observations and zero value predictions.
        sample_batch[CentralizedActorCriticModel.CRITIC_OBS] = critic_obs_array
        sample_batch[SampleBatch.VF_PREDS] = np.zeros(
            (batch_length,), dtype=np.float32
        )

    train_batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        policy.config["use_gae"],
    )
    return train_batch
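# `get_action_buffer` is referenced above but not defined in this snippet. A
# minimal sketch of what it plausibly does, assuming it mirrors the inline
# action handling in the mean-action variant below (one-hot for Discrete,
# pass-through for Box). The signature is taken from the call sites; the body
# is an assumption, not necessarily the project's actual implementation.
def get_action_buffer(action_space, action_preprocessor, batch, copy_length):
    if isinstance(action_space, spaces.Discrete):
        # One-hot encode integer actions -> (copy_length, action_preprocessor.size).
        return np.eye(action_preprocessor.size)[
            batch[SampleBatch.ACTIONS][:copy_length]
        ]
    elif isinstance(action_space, spaces.Box):
        # Continuous actions are already feature vectors; copy them through.
        return batch[SampleBatch.ACTIONS][:copy_length]
    raise NotImplementedError(f"Unsupported action space: {type(action_space)}")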
# Variant: instead of giving each opponent its own slot, summarize all
# opponents by their mean action and concatenate [own_obs, own_action,
# mean_action] as the critic observation.
def postprocess_trajectory(
    policy: TFPolicy, sample_batch: SampleBatch, other_agent_batches=None, episode=None
):
    last_r = 0.0
    batch_length = len(sample_batch[SampleBatch.CUR_OBS])
    action_preprocessor = policy.model.act_preprocessor
    obs_preprocessor = policy.model.obs_preprocessor
    mean_action = np.zeros((batch_length,) + action_preprocessor.shape)
    own_action = np.zeros((batch_length,) + action_preprocessor.shape)
    own_obs = np.zeros((batch_length,) + obs_preprocessor.shape)

    if policy.loss_initialized():
        sample_batch[SampleBatch.DONES][-1] = 1
        # ordered by agent keys
        other_agent_batches = OrderedDict(other_agent_batches)
        for i, (other_id, (other_policy, batch)) in enumerate(
            other_agent_batches.items()
        ):
            copy_length = min(batch_length, batch[SampleBatch.CUR_OBS].shape[0])
            # TODO(ming): check the action type
            if isinstance(policy.action_space, spaces.Discrete):
                # One-hot encode discrete actions before accumulating the mean.
                buffer_action = np.eye(action_preprocessor.size)[
                    batch[SampleBatch.ACTIONS][:copy_length]
                ]
            elif isinstance(policy.action_space, spaces.Box):
                buffer_action = batch[SampleBatch.ACTIONS][:copy_length]
            else:
                raise NotImplementedError(
                    f"Unsupported action space: {type(policy.action_space)}"
                )
            mean_action[:copy_length] += buffer_action

        # fill my features into the critic observation
        if isinstance(policy.action_space, spaces.Box):
            buffer_action = sample_batch[SampleBatch.ACTIONS]
        elif isinstance(policy.action_space, spaces.Discrete):
            buffer_action = np.eye(action_preprocessor.size)[
                sample_batch[SampleBatch.ACTIONS][:batch_length]
            ]
        else:
            raise NotImplementedError(
                f"Unsupported action space: {type(policy.action_space)}"
            )
        own_action[:batch_length] = buffer_action
        own_obs[:] = sample_batch[SampleBatch.CUR_OBS]
        mean_action /= max(1, len(other_agent_batches))
        sample_batch[CentralizedActorCriticModel.CRITIC_OBS] = np.concatenate(
            [own_obs, own_action, mean_action], axis=-1
        )
        sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
            sample_batch[CentralizedActorCriticModel.CRITIC_OBS]
        )
    else:
        sample_batch[CentralizedActorCriticModel.CRITIC_OBS] = np.concatenate(
            [own_obs, own_action, mean_action], axis=-1
        )
        sample_batch[SampleBatch.VF_PREDS] = np.zeros(
            (batch_length,), dtype=np.float32
        )

    train_batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        policy.config["use_gae"],
    )
    return train_batch
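# Both variants call `policy.compute_central_vf`, which must be attached to the
# policy before postprocessing runs with an initialized loss. A minimal sketch
# of one way to wire it up, following the mixin pattern from RLlib's
# centralized-critic example; `central_value_function` is an assumed method on
# CentralizedActorCriticModel, and the import path matches RLlib's TF policy
# stack of that era.
from ray.rllib.utils.tf_ops import make_tf_callable


class CentralizedValueMixin:
    """Exposes the model's central value function as an eager callable."""

    def __init__(self):
        # make_tf_callable lets the critic be evaluated on numpy inputs such as
        # sample_batch[CentralizedActorCriticModel.CRITIC_OBS] in postprocessing.
        self.compute_central_vf = make_tf_callable(self.get_session())(
            self.model.central_value_function
        )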