Example #1
def test_random_shooting_with_dynamic_step_driver(observation_space, action_space):
    """
    This test uses the environment wrapper as an adapter so that a driver from TF-Agents can be used
    to generate a rollout. This also serves as an example of how to construct "random shooting"
    rollouts from an environment model.

    The assertion in this test is that the selected action has the expected log_prob value,
    consistent with sampling from a uniform distribution. All this really checks is that the
    preceding code has run successfully.
    """

    network = LinearTransitionNetwork(observation_space)
    environment = KerasTransitionModel([network], observation_space, action_space)
    wrapped_environment = EnvironmentModel(
        environment,
        ConstantReward(observation_space, action_space, 0.0),
        ConstantFalseTermination(observation_space),
        create_uniform_initial_state_distribution(observation_space),
    )

    random_policy = RandomTFPolicy(
        wrapped_environment.time_step_spec(), action_space, emit_log_probability=True
    )

    transition_observer = _RecordLastLogProbTransitionObserver()

    driver = DynamicStepDriver(
        env=wrapped_environment,
        policy=random_policy,
        transition_observers=[transition_observer],
    )
    driver.run()

    last_log_prob = transition_observer.last_log_probability

    uniform_distribution = create_uniform_distribution_from_spec(action_space)
    action_log_prob = uniform_distribution.log_prob(transition_observer.action)
    expected = np.sum(action_log_prob.numpy().astype(np.float32))
    actual = np.sum(last_log_prob.numpy())

    np.testing.assert_array_almost_equal(actual, expected, decimal=4)
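
The helper `_RecordLastLogProbTransitionObserver` is referenced above but not shown. Below is a minimal sketch of what such a transition observer might look like; the attribute names and the `(time_step, policy_step, next_time_step)` transition layout are assumptions based on the TF-Agents transition-observer interface, not the library's actual implementation.

class _RecordLastLogProbTransitionObserver:
    """Hypothetical sketch: keeps the action and log-probability from the
    most recent transition produced by the driver."""

    def __init__(self):
        self.action = None
        self.last_log_probability = None

    def __call__(self, transition):
        # TF-Agents transition observers are called with
        # (time_step, policy_step, next_time_step) tuples.
        _, policy_step, _ = transition
        self.action = policy_step.action
        # With emit_log_probability=True the policy places the log-probability
        # in policy_step.info (see tf_agents.trajectories.policy_step).
        self.last_log_probability = policy_step.info.log_probability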
Example #2
    def __init__(
        self,
        environment_model: EnvironmentModel,
        trajectory_optimiser: TrajectoryOptimiser,
        clip: bool = True,
        emit_log_probability: bool = False,
        automatic_state_reset: bool = True,
        observation_and_action_constraint_splitter: Optional[types.Splitter] = None,
        validate_args: bool = True,
        name: Optional[Text] = None,
    ):
        """
        Initializes the class.

        :param environment_model: An `EnvironmentModel` is a model of the MDP that represents
            the environment, consisting of a transition, reward, termination and initial state
            distribution model, of which some are trainable and some are fixed.
        :param trajectory_optimiser: A `TrajectoryOptimiser` takes an environment model and
            optimises a sequence of actions over a given horizon using virtual rollouts.
        :param clip: Whether to clip actions to spec before returning them. By default True.
        :param emit_log_probability: Emit log-probabilities of actions, if supported. If
            True, policy_step.info will have CommonFields.LOG_PROBABILITY set.
            Please consult utility methods provided in policy_step for setting and
            retrieving these. When working with custom policies, either provide a
            dictionary info_spec or a namedtuple with the field 'log_probability'.
        :param automatic_state_reset: If `True`, then `get_initial_policy_state` is used
            to clear state in `action()` and `distribution()` for time steps
            where `time_step.is_first()`.
        :param observation_and_action_constraint_splitter: A function used to process
            observations with action constraints. These constraints can indicate,
            for example, a mask of valid/invalid actions for a given state of the
            environment. The function takes in a full observation and returns a
            tuple consisting of 1) the part of the observation intended as input to
            the network and 2) the constraint. An example
            `observation_and_action_constraint_splitter` could be as simple as:
            ```
            def observation_and_action_constraint_splitter(observation):
                return observation['network_input'], observation['constraint']
            ```
            *Note*: when using `observation_and_action_constraint_splitter`, make
              sure the provided `q_network` is compatible with the network-specific
              half of the output of the
              `observation_and_action_constraint_splitter`. In particular,
              `observation_and_action_constraint_splitter` will be called on the
              observation before passing to the network. If
              `observation_and_action_constraint_splitter` is None, action
              constraints are not applied.
        :param validate_args: Python bool.  Whether to verify inputs to, and outputs of,
            functions like `action` and `distribution` against spec structures,
            dtypes, and shapes. Research code may prefer to set this value to `False`
            to allow iterating on input and output structures without being hamstrung
            by overly rigid checking (at the cost of harder-to-debug errors). See also
            `TFAgent.validate_args`.
        :param name: A name for this module. Defaults to the class name.
        """

        self.trajectory_optimiser = trajectory_optimiser
        self._environment_model = environment_model

        # making sure the batch_size in environment_model is correctly set
        self._environment_model.batch_size = self.trajectory_optimiser.batch_size

        super(PlanningPolicy, self).__init__(
            time_step_spec=environment_model.time_step_spec(),
            action_spec=environment_model.action_spec(),
            policy_state_spec=(),
            info_spec=(),
            clip=clip,
            emit_log_probability=emit_log_probability,
            automatic_state_reset=automatic_state_reset,
            observation_and_action_constraint_splitter=
            observation_and_action_constraint_splitter,
            validate_args=validate_args,
            name=name,
        )
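
A hypothetical usage sketch follows. The environment model and trajectory optimiser are assumed to have been built elsewhere (e.g. a wrapped environment as in Example #1); only the `PlanningPolicy` constructor arguments and the standard TF-Agents policy interface are taken from the code above.

# Hypothetical sketch: `wrapped_environment` and `trajectory_optimiser` are assumed
# to exist; they are not constructed here.
planning_policy = PlanningPolicy(
    environment_model=wrapped_environment,
    trajectory_optimiser=trajectory_optimiser,
)

# Standard TF-Agents policy interface: plan an action for the current time step.
time_step = wrapped_environment.reset()
action_step = planning_policy.action(time_step)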