def test_random_shooting_with_dynamic_step_driver(observation_space, action_space):
    """
    This test uses the environment wrapper as an adapter so that a driver from
    TF-Agents can be used to generate a rollout. It also serves as an example of
    how to construct "random shooting" rollouts from an environment model.

    The assertion in this test is that the selected action has the expected
    log_prob value, consistent with sampling from a uniform distribution. All
    this really checks is that the preceding code has run successfully.
    """
    network = LinearTransitionNetwork(observation_space)
    environment = KerasTransitionModel([network], observation_space, action_space)
    wrapped_environment = EnvironmentModel(
        environment,
        ConstantReward(observation_space, action_space, 0.0),
        ConstantFalseTermination(observation_space),
        create_uniform_initial_state_distribution(observation_space),
    )

    random_policy = RandomTFPolicy(
        wrapped_environment.time_step_spec(), action_space, emit_log_probability=True
    )
    transition_observer = _RecordLastLogProbTransitionObserver()
    driver = DynamicStepDriver(
        env=wrapped_environment,
        policy=random_policy,
        transition_observers=[transition_observer],
    )
    driver.run()

    last_log_prob = transition_observer.last_log_probability
    uniform_distribution = create_uniform_distribution_from_spec(action_space)
    action_log_prob = uniform_distribution.log_prob(transition_observer.action)
    expected = np.sum(action_log_prob.numpy().astype(np.float32))
    actual = np.sum(last_log_prob.numpy())
    np.testing.assert_array_almost_equal(actual, expected, decimal=4)
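
# A minimal sketch of the observer used in the test above, assuming the
# TF-Agents transition-observer interface in which each observer is called
# with a (time_step, policy_step, next_time_step) tuple; the repository's
# actual implementation may differ. It records the most recent action and its
# log-probability so they can be inspected after `driver.run()`.
class _RecordLastLogProbTransitionObserver:
    def __init__(self):
        self.last_log_probability = None
        self.action = None

    def __call__(self, transition):
        _, policy_step, _ = transition
        # With `emit_log_probability=True`, the policy places the
        # log-probability of the sampled action in `policy_step.info`.
        self.action = policy_step.action
        self.last_log_probability = policy_step.info.log_probability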
def __init__(
    self,
    environment_model: EnvironmentModel,
    trajectory_optimiser: TrajectoryOptimiser,
    clip: bool = True,
    emit_log_probability: bool = False,
    automatic_state_reset: bool = True,
    observation_and_action_constraint_splitter: Optional[types.Splitter] = None,
    validate_args: bool = True,
    name: Optional[Text] = None,
):
    """
    Initializes the class.

    :param environment_model: An `EnvironmentModel` is a model of the MDP that
        represents the environment, consisting of a transition, reward, termination
        and initial state distribution model, of which some are trainable and some
        are fixed.
    :param trajectory_optimiser: A `TrajectoryOptimiser` takes an environment model
        and optimises a sequence of actions over a given horizon using virtual
        rollouts.
    :param clip: Whether to clip actions to spec before returning them. Defaults to
        `True`.
    :param emit_log_probability: Emit log-probabilities of actions, if supported. If
        `True`, `policy_step.info` will have `CommonFields.LOG_PROBABILITY` set.
        Please consult utility methods provided in `policy_step` for setting and
        retrieving these. When working with custom policies, either provide a
        dictionary `info_spec` or a namedtuple with the field 'log_probability'.
    :param automatic_state_reset: If `True`, then `get_initial_policy_state` is
        used to clear state in `action()` and `distribution()` for time steps where
        `time_step.is_first()`.
    :param observation_and_action_constraint_splitter: A function used to process
        observations with action constraints. These constraints can indicate, for
        example, a mask of valid/invalid actions for a given state of the
        environment. The function takes in a full observation and returns a tuple
        consisting of 1) the part of the observation intended as input to the
        network and 2) the constraint. An example
        `observation_and_action_constraint_splitter` could be as simple as:

        ```
        def observation_and_action_constraint_splitter(observation):
            return observation['network_input'], observation['constraint']
        ```

        *Note*: when using `observation_and_action_constraint_splitter`, make sure
        the provided `q_network` is compatible with the network-specific half of
        the output of the `observation_and_action_constraint_splitter`. In
        particular, `observation_and_action_constraint_splitter` will be called on
        the observation before passing to the network. If
        `observation_and_action_constraint_splitter` is `None`, action constraints
        are not applied.
    :param validate_args: Python bool. Whether to verify inputs to, and outputs of,
        functions like `action` and `distribution` against spec structures, dtypes,
        and shapes. Research code may prefer to set this value to `False` to allow
        iterating on input and output structures without being hamstrung by overly
        rigid checking (at the cost of harder-to-debug errors). See also
        `TFAgent.validate_args`.
    :param name: A name for this module. Defaults to the class name.
    """
    self.trajectory_optimiser = trajectory_optimiser
    self._environment_model = environment_model

    # make sure the batch size of the environment model matches the batch size
    # used by the trajectory optimiser
    self._environment_model.batch_size = self.trajectory_optimiser.batch_size

    super(PlanningPolicy, self).__init__(
        time_step_spec=environment_model.time_step_spec(),
        action_spec=environment_model.action_spec(),
        policy_state_spec=(),
        info_spec=(),
        clip=clip,
        emit_log_probability=emit_log_probability,
        automatic_state_reset=automatic_state_reset,
        observation_and_action_constraint_splitter=observation_and_action_constraint_splitter,
        validate_args=validate_args,
        name=name,
    )
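
# A hedged usage sketch, not part of the library: it shows how a
# `PlanningPolicy` might be constructed and queried for an action. Both
# arguments are assumptions here; `environment_model` stands in for an
# `EnvironmentModel` (such as `wrapped_environment` in the test above) and
# `trajectory_optimiser` for any concrete `TrajectoryOptimiser`.
def _example_planning_policy_usage(environment_model, trajectory_optimiser):
    planning_policy = PlanningPolicy(
        environment_model=environment_model,
        trajectory_optimiser=trajectory_optimiser,
    )
    # The policy plans by optimising action sequences over virtual rollouts
    # inside the environment model, then returns the first planned action.
    time_step = environment_model.reset()
    return planning_policy.action(time_step)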