    def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
        """Returns log likelihood for actions in given batch for policy.

        Computes the log-likelihoods by passing the observations and actions
        through the current policy's `compute_log_likelihoods()` method.

        Args:
            batch: The SampleBatch or MultiAgentBatch to calculate action
                log-likelihoods for. The batch must contain the OBS and
                ACTIONS keys.

        Returns:
            The log-likelihoods of the actions in the batch, given the
            observations and the policy.
        """
        num_state_inputs = 0
        for k in batch.keys():
            if k.startswith("state_in_"):
                num_state_inputs += 1
        state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
        log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
            actions=batch[SampleBatch.ACTIONS],
            obs_batch=batch[SampleBatch.OBS],
            state_batches=[batch[k] for k in state_keys],
            prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
            actions_normalized=True,
        )
        log_likelihoods = convert_to_numpy(log_likelihoods)
        return log_likelihoods
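A minimal calling sketch for the method above (not part of the original snippet): `estimator` is assumed to be an object that exposes `action_log_likelihood()` and wraps a trained policy, and the batch shapes are purely illustrative.

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# Toy batch carrying the two required keys; a real batch would come from rollouts.
batch = SampleBatch({
    SampleBatch.OBS: np.random.randn(32, 4).astype(np.float32),
    SampleBatch.ACTIONS: np.random.randint(0, 2, size=(32,)),
})
log_probs = estimator.action_log_likelihood(batch)  # one log-likelihood per step, as numpy
action_probs = np.exp(log_probs)  # exponentiate if probabilities are needed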
Example #2
    def split_train_val(self, samples: SampleBatchType):
        """Randomly splits the given SampleBatch into train and validation batches.

        The fraction of rows held out for validation is `self.valid_split`.
        Assumes every column in `samples` is a 2-D array.
        """
        dataset_size = samples.count
        indices = np.arange(dataset_size)
        np.random.shuffle(indices)
        split_idx = int(dataset_size * (1 - self.valid_split))
        idx_train = indices[:split_idx]
        idx_test = indices[split_idx:]

        train = {}
        val = {}
        for key in samples.keys():
            train[key] = samples[key][idx_train, :]
            val[key] = samples[key][idx_test, :]
        return SampleBatch(train), SampleBatch(val)

    def action_prob(self, batch: SampleBatchType) -> TensorType:
        """Returns the probs for the batch actions for the current policy."""

        num_state_inputs = 0
        for k in batch.keys():
            if k.startswith("state_in_"):
                num_state_inputs += 1
        state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
        log_likelihoods = self.policy.compute_log_likelihoods(
            actions=batch[SampleBatch.ACTIONS],
            obs_batch=batch[SampleBatch.CUR_OBS],
            state_batches=[batch[k] for k in state_keys],
            prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS))
        log_likelihoods = convert_to_numpy(log_likelihoods)
        # Exponentiate so the return value matches the docstring: per-action probabilities.
        return np.exp(log_likelihoods)
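The shuffled train/validation split used by `split_train_val()` above can be reproduced standalone. The sketch below inlines the same logic with an assumed `valid_split` of 0.2 and small synthetic 2-D columns; all names and shapes here are illustrative.

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

valid_split = 0.2  # assumed hold-out fraction (plays the role of self.valid_split)
samples = SampleBatch({
    SampleBatch.OBS: np.random.randn(100, 4).astype(np.float32),
    SampleBatch.ACTIONS: np.random.randn(100, 2).astype(np.float32),
})
# Shuffle row indices, then cut them into a train part and a validation part.
indices = np.arange(samples.count)
np.random.shuffle(indices)
split_idx = int(samples.count * (1 - valid_split))
train = SampleBatch({k: samples[k][indices[:split_idx], :] for k in samples.keys()})
val = SampleBatch({k: samples[k][indices[split_idx:], :] for k in samples.keys()})
print(train.count, val.count)  # 80 and 20 rows for a 0.2 validation split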
Example #4
    def action_prob(self, batch: SampleBatchType) -> np.ndarray:
        """Returns the probs for the batch actions for the current policy."""

        num_state_inputs = 0
        for k in batch.keys():
            if k.startswith("state_in_"):
                num_state_inputs += 1
        state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
        log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
            actions=batch[SampleBatch.ACTIONS],
            obs_batch=batch[SampleBatch.CUR_OBS],
            state_batches=[batch[k] for k in state_keys],
            prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
            actions_normalized=True,
        )
        log_likelihoods = convert_to_numpy(log_likelihoods)
        return np.exp(log_likelihoods)
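As a purely illustrative follow-up (not stated in the snippets above): probabilities like the ones returned by `action_prob()` are often turned into per-step importance ratios against the behavior policy's logged probabilities, assuming the batch also carries the standard ACTION_PROB column. `estimator` and `batch` are the same assumed objects as in the earlier sketches.

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

new_prob = estimator.action_prob(batch)                # target policy, via the method above
old_prob = np.asarray(batch[SampleBatch.ACTION_PROB])  # behavior policy, logged at sampling time
weights = new_prob / np.clip(old_prob, 1e-8, None)     # per-step importance weights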