def _get_mean_var_estimates(self):
  """Returns this normalizer's current estimates for mean & variance."""
  mean_estimate = nest_utils.map_structure_up_to(
      self._flat_tensor_spec, lambda a, b: a / b, self._mean_sum, self._count)
  var_estimate = nest_utils.map_structure_up_to(
      self._flat_tensor_spec, lambda a, b: a / b, self._var_sum, self._count)
  return mean_estimate, var_estimate
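
# Illustrative standalone sketch (the names below are made up, not library
# API): given per-element running sums and sample counts accumulated from a
# data stream, both estimates above are plain element-wise divisions mapped
# over the nest.
import tensorflow as tf

mean_sum = {'obs': tf.constant([4.0, 8.0])}  # Running sum of observed values.
var_sum = {'obs': tf.constant([2.0, 2.0])}   # Running sum of squared deviations.
count = {'obs': tf.constant([4.0, 4.0])}     # Running per-element sample counts.

mean_estimate = tf.nest.map_structure(lambda a, b: a / b, mean_sum, count)
var_estimate = tf.nest.map_structure(lambda a, b: a / b, var_sum, count)
# mean_estimate == {'obs': [1.0, 2.0]}, var_estimate == {'obs': [0.5, 0.5]}.
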

def distribution_from_spec(
    spec: types.NestedTensorSpec,
    new_distribution_params: types.NestedTensor,
    legacy_distribution_network: bool) -> types.NestedDistribution:
  """Convert `(spec, new_distribution_params) -> Distribution`.

  `new_distribution_params` typically comes from a logged policy info.

  Args:
    spec: A nested tensor spec. If `legacy_distribution_network` is `True`,
      these are typically `actor_net.output_spec`. If it's `False`, these are
      typically the output of `actor_net.create_variables()`.
    new_distribution_params: Parameters to merge with the spec to create a new
      distribution. These were typically emitted by `get_distribution_params`
      and stored in the replay buffer.
    legacy_distribution_network: `True` if the spec and params were generated
      from a `network.DistributionNetwork`.

  Returns:
    A (possibly nested set of) `Distribution` created from the spec merged
    with the new params.
  """
  if legacy_distribution_network:
    return distribution_spec.nested_distributions_from_specs(
        spec, new_distribution_params)
  else:

    def merge_and_convert(spec, params):
      return distribution_utils.make_from_parameters(
          distribution_utils.merge_to_parameters_from_dict(
              spec.parameters, params))

    return nest_utils.map_structure_up_to(
        spec, merge_and_convert, spec, new_distribution_params)
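
# Hedged usage sketch for the non-legacy branch: parameter tensors logged in
# policy info (e.g. a Gaussian's loc/scale) are merged back into the spec's
# parameters to rebuild the same distribution. The analogous round trip with
# raw TFP, using made-up values:
import tensorflow_probability as tfp

logged_params = {'loc': [0.0, 1.0], 'scale': [1.0, 0.5]}  # E.g. from a replay buffer.
rebuilt = tfp.distributions.Normal(**logged_params)
sample = rebuilt.sample()
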

def normalize(self,
              tensor,
              clip_value=5.0,
              center_mean=True,
              variance_epsilon=1e-3):
  """Applies normalization to tensor.

  Args:
    tensor: Tensor to normalize.
    clip_value: Clips normalized observations between +/- this value if
      clip_value > 0, otherwise does not apply clipping.
    center_mean: If true, subtracts off mean from normalized tensor.
    variance_epsilon: Epsilon to avoid division by zero in normalization.

  Returns:
    normalized_tensor: Tensor after applying normalization.
  """
  nest_utils.assert_same_structure(tensor, self._tensor_spec)
  tensor = tf.nest.flatten(tensor)
  tensor = tf.nest.map_structure(lambda t: tf.cast(t, tf.float32), tensor)
  with tf.name_scope(self._scope + '/normalize'):
    mean_estimate, var_estimate = self._get_mean_var_estimates()
    mean = (mean_estimate if center_mean else
            tf.nest.map_structure(tf.zeros_like, mean_estimate))

    def _normalize_single_tensor(single_tensor, single_mean, single_var):
      return tf.nn.batch_normalization(
          single_tensor, single_mean, single_var,
          offset=None, scale=None, variance_epsilon=variance_epsilon,
          name='normalized_tensor')

    normalized_tensor = nest_utils.map_structure_up_to(
        self._flat_tensor_spec, _normalize_single_tensor,
        tensor, mean, var_estimate, check_types=False)

    if clip_value > 0:

      def _clip(t):
        return tf.clip_by_value(t, -clip_value, clip_value,
                                name='clipped_normalized_tensor')

      normalized_tensor = tf.nest.map_structure(_clip, normalized_tensor)

  normalized_tensor = tf.nest.pack_sequence_as(self._tensor_spec,
                                               normalized_tensor)
  return normalized_tensor
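
# Standalone sketch of the normalization applied above, with illustrative
# values: shift by the mean estimate, divide by sqrt(var + epsilon), then
# clip to +/- clip_value (5.0 by default). The second element lands outside
# the clip range and saturates.
import tensorflow as tf

t = tf.constant([[-2.0, 10.0]])
mean = tf.constant([0.0, 4.0])
var = tf.constant([1.0, 1.0])

normalized = tf.nn.batch_normalization(
    t, mean, var, offset=None, scale=None, variance_epsilon=1e-3)
clipped = tf.clip_by_value(normalized, -5.0, 5.0)
# normalized ~= [[-2.0, 6.0]]; clipped ~= [[-2.0, 5.0]].
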

def call(self, inputs, network_state=(), **kwargs):
  """Applies each layer in the layer nest to its matching entry in `inputs`."""
  nest_utils.assert_same_structure(
      self._nested_layers,
      inputs,
      allow_shallow_nest1=True,
      message=('`self.nested_layers` and `inputs` do not have matching '
               'structures'))

  if network_state:
    nest_utils.assert_same_structure(
        self.state_spec,
        network_state,
        allow_shallow_nest1=True,
        message='network_state and state_spec do not have matching structure')
    nested_layers_state = network_state
  else:
    nested_layers_state = tf.nest.map_structure(
        lambda _: (), self._nested_layers)

  # Here we must use map_structure_up_to because nested_layers_state has a
  # "deeper" structure than self._nested_layers. For example, an LSTM
  # layer's state is composed of a list with two tensors. The
  # tf.nest.map_structure function would raise an error if two
  # "incompatible" structures are passed in this way.
  def _mapper(inp, layer, state):  # pylint: disable=invalid-name
    return layer(inp, network_state=state, **kwargs)

  outputs_and_next_state = nest_utils.map_structure_up_to(
      self._nested_layers, _mapper,
      inputs, self._nested_layers, nested_layers_state)

  flat_outputs_and_next_state = nest_utils.flatten_up_to(
      self._nested_layers, outputs_and_next_state)
  flat_outputs, flat_next_state = zip(*flat_outputs_and_next_state)

  outputs = tf.nest.pack_sequence_as(self._nested_layers, flat_outputs)
  next_network_state = tf.nest.pack_sequence_as(
      self._nested_layers, flat_next_state)
  return outputs, next_network_state
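
# Why map_structure_up_to rather than tf.nest.map_structure, as a sketch with
# illustrative placeholder values: the layer nest is the shallow structure,
# and states may nest deeper beneath each layer (an LSTM state is a pair of
# tensors), which plain map_structure would reject as a structure mismatch.
import tensorflow as tf
from tf_agents.utils import nest_utils

layers = {'dense': 'dense_layer', 'lstm': 'lstm_layer'}  # Shallow nest.
states = {'dense': (), 'lstm': [tf.zeros([1, 4]), tf.zeros([1, 4])]}

paired = nest_utils.map_structure_up_to(
    layers, lambda layer, state: (layer, state), layers, states)
# paired['lstm'] == ('lstm_layer', [<zeros>, <zeros>]): the deeper state
# subtree is passed through whole instead of being flattened leaf-by-leaf.
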

def call(self, observations, step_type, network_state):
  """Maps observations to a nest of `MultivariateNormalDiag` distributions."""
  del step_type  # Unused.
  states = tf.cast(tf.nest.flatten(observations)[0], tf.float32)
  for layer in self._dummy_layers:
    states = layer(states)

  single_action_spec = tf.nest.flatten(self._output_tensor_spec)[0]
  # action_spec is TensorSpec([1], ...) so make sure there's an outer dim.
  actions = states[..., 0]
  stdevs = states[..., 1]
  actions = tf.reshape(actions, [-1] + single_action_spec.shape.as_list())
  stdevs = tf.reshape(stdevs, [-1] + single_action_spec.shape.as_list())
  actions = tf.nest.pack_sequence_as(self._output_tensor_spec, [actions])
  stdevs = tf.nest.pack_sequence_as(self._output_tensor_spec, [stdevs])

  distribution = nest_utils.map_structure_up_to(
      self._output_tensor_spec, tfp.distributions.MultivariateNormalDiag,
      actions, stdevs)
  return distribution, network_state
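
# Sketch of the distribution construction above, with illustrative values:
# each entry of the flattened action spec gets a diagonal Gaussian built from
# its mean and stddev tensors.
import tensorflow as tf
import tensorflow_probability as tfp

means = tf.constant([[0.5]])   # Shape [batch, 1], matching TensorSpec([1], ...).
stdevs = tf.constant([[0.1]])
dist = tfp.distributions.MultivariateNormalDiag(means, stdevs)
action_sample = dist.sample()  # Shape [1, 1].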