def _check_network_output(self, net, label):
    """Verify the output of `net` against the expected Q-value shape.

    Calls `net.create_variables()` and passes the result to
    `network_utils.check_single_floating_network_output`, expecting an output
    shape of `(self._num_actions,)`. Used for both q_net and target_q_net.

    Subclasses that require different q_network outputs should override this
    function.

    Args:
      net: A `Network`.
      label: A label to print in case of a mismatch.
    """
    # Create the variables first, then read the expected shape, so the order
    # of attribute accesses matches the single-expression form.
    output_spec = net.create_variables()
    expected_shape = (self._num_actions,)
    network_utils.check_single_floating_network_output(
        output_spec,
        expected_output_shape=expected_shape,
        label=label)
def __init__(
    self,
    time_step_spec: ts.TimeStep,
    action_spec: types.NestedTensorSpec,
    q_network: network.Network,
    emit_log_probability: bool = False,
    observation_and_action_constraint_splitter: Optional[
        types.Splitter] = None,
    validate_action_spec_and_network: bool = True,
    name: Optional[Text] = None):
    """Builds a Q-Policy given a q_network.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions. It
        must flatten to exactly one spec.
      q_network: An instance of a `tf_agents.network.Network`,
        callable via `network(observation, step_type) -> (output, final_state)`.
      emit_log_probability: Whether to emit log-probs in info of `PolicyStep`.
      observation_and_action_constraint_splitter: A function used to process
        observations with action constraints. These constraints can indicate,
        for example, a mask of valid/invalid actions for a given state of the
        environment.
        The function takes in a full observation and returns a tuple consisting
        of 1) the part of the observation intended as input to the network and
        2) the constraint. An example
        `observation_and_action_constraint_splitter` could be as simple as:
        ```
        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']
        ```
        *Note*: when using `observation_and_action_constraint_splitter`, make
        sure the provided `q_network` is compatible with the network-specific
        half of the output of the `observation_and_action_constraint_splitter`.
        In particular, `observation_and_action_constraint_splitter` will be
        called on the observation before passing to the network.
        If `observation_and_action_constraint_splitter` is None, action
        constraints are not applied.
      validate_action_spec_and_network: If `True` (default), action_spec is
        checked to make sure it is a single scalar spec with a minimum of zero.
        Also validates that the network's output matches the spec.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      ValueError: If `q_network.action_spec` exists and is not compatible with
        `action_spec`.
      ValueError: If `action_spec` does not flatten to exactly one spec.
      ValueError: If `validate_action_spec_and_network` is `True` and the
        action spec has rank greater than zero or a minimum other than 0.
    """
    action_spec = tensor_spec.from_spec(action_spec)
    time_step_spec = tensor_spec.from_spec(time_step_spec)

    network_action_spec = getattr(q_network, 'action_spec', None)
    if network_action_spec is not None:
        action_spec = cast(tf.TypeSpec, action_spec)
        if not action_spec.is_compatible_with(network_action_spec):
            raise ValueError(
                'action_spec must be compatible with q_network.action_spec; '
                'instead got action_spec=%s, q_network.action_spec=%s' % (
                    action_spec, network_action_spec))

    flat_action_spec = tf.nest.flatten(action_spec)
    # Reject empty nests as well as multi-spec nests: indexing element 0 of an
    # empty list below would otherwise surface as an unhelpful IndexError.
    if len(flat_action_spec) != 1:
        raise ValueError(
            'Only scalar actions are supported now, but action spec is: {}'
            .format(action_spec))
    single_spec = flat_action_spec[0]

    if validate_action_spec_and_network:
        if single_spec.shape.rank > 0:
            raise ValueError(
                'Only scalar actions are supported now, but action spec is: {}'
                .format(action_spec))

        if single_spec.minimum != 0:
            raise ValueError(
                'Action specs should have minimum of 0, but saw: {0}'.format(
                    single_spec))

        # With a minimum of 0 this is the count of discrete actions, i.e. the
        # expected size of the network's output.
        num_actions = single_spec.maximum - single_spec.minimum + 1
        network_utils.check_single_floating_network_output(
            q_network.create_variables(), (num_actions,), str(q_network))

    # We need to maintain the flat action spec for dtype, shape and range.
    self._flat_action_spec = single_spec

    self._q_network = q_network
    super(QPolicy, self).__init__(
        time_step_spec,
        action_spec,
        policy_state_spec=q_network.state_spec,
        clip=False,
        emit_log_probability=emit_log_probability,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter),
        name=name)
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             q_network: network.Network,
             min_q_value: float,
             max_q_value: float,
             observation_and_action_constraint_splitter: Optional[
                 types.Splitter] = None,
             temperature: types.Float = 1.0):
    """Builds a categorical Q-policy given a categorical Q-network.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A `BoundedTensorSpec` representing the actions.
      q_network: A network.Network to use for our policy.
      min_q_value: A float specifying the minimum Q-value, used for setting up
        the support.
      max_q_value: A float specifying the maximum Q-value, used for setting up
        the support.
      observation_and_action_constraint_splitter: A function used to process
        observations with action constraints. These constraints can indicate,
        for example, a mask of valid/invalid actions for a given state of the
        environment.
        The function takes in a full observation and returns a tuple consisting
        of 1) the part of the observation intended as input to the network and
        2) the constraint. An example
        `observation_and_action_constraint_splitter` could be as simple as:
        ```
        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']
        ```
        *Note*: when using `observation_and_action_constraint_splitter`, make
        sure the provided `q_network` is compatible with the network-specific
        half of the output of the `observation_and_action_constraint_splitter`.
        In particular, `observation_and_action_constraint_splitter` will be
        called on the observation before passing to the network.
        If `observation_and_action_constraint_splitter` is None, action
        constraints are not applied.
      temperature: temperature for sampling, when close to 0.0 is arg_max.

    Raises:
      ValueError: If `q_network.action_spec` exists and is not compatible with
        `action_spec`.
      ValueError: If `action_spec.minimum` is not 0.
      ValueError: if `q_network` does not have property `num_atoms`.
      TypeError: if `action_spec` is not a `BoundedTensorSpec`.
    """
    network_action_spec = getattr(q_network, 'action_spec', None)
    if network_action_spec is not None:
        action_spec = cast(tf.TypeSpec, action_spec)
        if not action_spec.is_compatible_with(network_action_spec):
            raise ValueError(
                'action_spec must be compatible with q_network.action_spec; '
                'instead got action_spec=%s, q_network.action_spec=%s' %
                (action_spec, network_action_spec))

    if not isinstance(action_spec, tensor_spec.BoundedTensorSpec):
        raise TypeError(
            'action_spec must be a BoundedTensorSpec. Got: %s' % (action_spec,))

    action_spec = cast(tensor_spec.BoundedTensorSpec, action_spec)

    if action_spec.minimum != 0:
        raise ValueError(
            'Action specs should have minimum of 0, but saw: {0}. If collecting '
            'from a python environment, consider using '
            'tf_agents.environments.wrappers.ActionOffsetWrapper.'.format(
                action_spec))

    num_actions = action_spec.maximum - action_spec.minimum + 1

    try:
        num_atoms = q_network.num_atoms
    except AttributeError as err:
        # Chain explicitly so the traceback records the AttributeError as the
        # direct cause of this ValueError.
        raise ValueError(
            'Expected q_network to have property `num_atoms`, but '
            'it doesn\'t. (Note: you likely want to use a '
            'CategoricalQNetwork.) Network is: %s' % (q_network,)) from err

    self._num_atoms = num_atoms

    # The network output is checked against one row of `num_atoms` values per
    # action.
    network_utils.check_single_floating_network_output(
        q_network.create_variables(), (num_actions, num_atoms), str(q_network))

    super(CategoricalQPolicy, self).__init__(
        time_step_spec,
        action_spec,
        policy_state_spec=q_network.state_spec,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter))

    self._temperature = tf.convert_to_tensor(temperature, dtype=tf.float32)
    self._q_network = q_network

    # Generate support in numpy so that we can assign it to a constant and avoid
    # having a tensor property.
    support = np.linspace(min_q_value, max_q_value, self._num_atoms,
                          dtype=np.float32)
    self._support = tf.constant(support, dtype=tf.float32)
    self._action_dtype = action_spec.dtype
def _check_network_output(self, net, label):
    """Verify the output of `net` against the categorical Q-value shape.

    Calls `net.create_variables()` and passes the result to
    `network_utils.check_single_floating_network_output`, expecting an output
    shape of `(self._num_actions, self._num_atoms)`.

    Args:
      net: The network whose output is checked; it must provide
        `create_variables()`.
      label: Forwarded as the `label` argument of the check.
    """
    # Create the variables first, then read the expected shape, so the order
    # of attribute accesses matches the single-expression form.
    output_spec = net.create_variables()
    expected_shape = (self._num_actions, self._num_atoms)
    network_utils.check_single_floating_network_output(
        output_spec,
        expected_output_shape=expected_shape,
        label=label)
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             q_network: Optional[network.Network],
             sampler: cem_actions_sampler.ActionsSampler,
             init_mean: types.NestedArray,
             init_var: types.NestedArray,
             actor_policy: Optional[tf_policy.TFPolicy] = None,
             minimal_var: float = 0.0001,
             info_spec: types.NestedSpecTensorOrArray = (),
             num_samples: int = 32,
             num_elites: int = 4,
             num_iterations: int = 32,
             emit_log_probability: bool = False,
             preprocess_state_action: bool = True,
             training: bool = False,
             weights: Optional[types.NestedTensorOrArray] = None,
             name: Optional[str] = None):
    """Builds a CEM-Policy given a network and a sampler.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: An instance of a `tf_agents.network.Network`,
        callable via `network(observation, step_type) -> (output, final_state)`.
        May be `None`, in which case the network output check is skipped and
        the policy state spec is the empty tuple.
      sampler: Samples the actions needed for the CEM.
      init_mean: A list or tuple or scalar, representing initial mean for
        actions.
      init_var: A list or tuple or scalar, representing initial var for
        actions.
      actor_policy: Optional actor policy.
      minimal_var: Minimal variance to prevent CEM distribution collapsing.
      info_spec: A policy info spec.
      num_samples: Number of samples to sample each round.
      num_elites: Number of best actions each round to refit the distribution
        with.
      num_iterations: Number of iterations to run the CEM loop.
      emit_log_probability: Whether to emit log-probs in info of `PolicyStep`.
      preprocess_state_action: The shape of state is (B, ...) and the shape of
        action is (B, N, A). When preprocess_state_action is enabled, the state
        will be tile_batched to be (BxN, ...) and the action will be reshaped
        to be (BxN, A). When preprocess_state_action is not enabled, the same
        operation needs to be done inside the network. This is helpful when the
        input have large memory requirements and the replication of state could
        happen after a few layers inside the network.
      training: Whether it is in training mode or inference mode.
      weights: A nested structure of weights w/ the same structure as action.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      ValueError: If `q_network.action_spec` exists and is not compatible with
        `action_spec`.
    """
    # `getattr` with a default also covers `q_network=None`.
    network_action_spec = getattr(q_network, 'action_spec', None)

    if network_action_spec is not None:
        if not action_spec.is_compatible_with(network_action_spec):
            raise ValueError(
                'action_spec must be compatible with q_network.action_spec; '
                'instead got action_spec=%s, q_network.action_spec=%s' %
                (action_spec, network_action_spec))

    # Test for absence explicitly rather than relying on the truthiness of the
    # network object.
    if q_network is not None:
        network_utils.check_single_floating_network_output(
            q_network.create_variables(),
            expected_output_shape=(),
            label=str(q_network))
        policy_state_spec = q_network.state_spec
    else:
        policy_state_spec = ()

    self._actor_policy = actor_policy
    self._q_network = q_network
    self._init_mean = init_mean
    self._init_var = init_var
    self._minimal_var = minimal_var
    self._num_samples = num_samples  # N
    self._num_elites = num_elites  # M
    self._num_iterations = num_iterations
    self._actions_sampler = sampler
    self._observation_spec = time_step_spec.observation
    self._training = training
    self._preprocess_state_action = preprocess_state_action
    self._weights = weights

    super(CEMPolicy, self).__init__(
        time_step_spec,
        action_spec,
        info_spec=info_spec,
        policy_state_spec=policy_state_spec,
        clip=False,
        emit_log_probability=emit_log_probability,
        name=name)