def call(self, observation, step_type=None, network_state=None):
  num_outer_dims = nest_utils.get_outer_rank(observation,
                                             self.observation_spec)
  if num_outer_dims not in (1, 2):
    raise ValueError(
        'Input observation must have a batch or batch x time outer shape.')

  has_time_dim = num_outer_dims == 2
  if not has_time_dim:
    # Add a time dimension to the inputs.
    observation = nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                     observation)
    step_type = nest.map_structure(lambda t: tf.expand_dims(t, 1), step_type)

  states = tf.cast(nest.flatten(observation)[0], tf.float32)
  batch_squash = utils.BatchSquash(2)  # Squash B and T dims.

  states = batch_squash.flatten(states)
  for layer in self._input_layers:
    states = layer(states)
  states = batch_squash.unflatten(states)

  with tf.name_scope('reset_mask'):
    reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

  # Unroll over the time sequence.
  states, network_state, _ = rnn_utils.dynamic_unroll(
      self._cell,
      states,
      reset_mask,
      initial_state=network_state,
      dtype=tf.float32)

  states = batch_squash.flatten(states)
  for layer in self._output_layers:
    states = layer(states)

  value = self._value_projection_layer(states)
  value = tf.reshape(value, [-1])
  value = batch_squash.unflatten(value)
  return value, network_state

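# --- Note: the snippets in this file all lean on the same BatchSquash
# pattern: merge the batch (B) and time (T) dimensions so plain,
# non-recurrent layers can be applied per timestep, then restore them.
# A minimal standalone sketch, assuming tf_agents is installed; the
# shapes below are illustrative.
import tensorflow as tf
from tf_agents.networks import utils

x = tf.zeros([4, 7, 16])                 # [B=4, T=7, features=16]
batch_squash = utils.BatchSquash(2)      # squash the B and T dims
flat = batch_squash.flatten(x)           # -> [28, 16]
flat = tf.keras.layers.Dense(32)(flat)   # per-timestep layer, no time axis
restored = batch_squash.unflatten(flat)  # -> [4, 7, 32]
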
def call(self, observation, step_type=None, network_state=()):
  del step_type  # unused.

  if self._batch_squash:
    outer_rank = nest_utils.get_outer_rank(observation, self.observation_spec)
    batch_squash = utils.BatchSquash(outer_rank)

  # Get single observation out regardless of nesting.
  states = tf.cast(nest.flatten(observation)[0], tf.float32)

  if self._batch_squash:
    states = batch_squash.flatten(states)

  for layer in self.layers:
    states = layer(states)

  if self._batch_squash:
    states = batch_squash.unflatten(states)
  return states, network_state

def call(self, observation, step_type, network_state=None):
  outer_rank = nest_utils.get_outer_rank(observation, self.input_tensor_spec)
  batch_squash = utils.BatchSquash(outer_rank)

  observation, network_state = self._lstm_encoder(
      observation, step_type=step_type, network_state=network_state)

  states = batch_squash.flatten(observation)

  actions = []
  for layer, spec in zip(self._action_layers, self._flat_action_spec):
    action = layer(states)
    action = common.scale_to_spec(action, spec)
    action = batch_squash.unflatten(action)
    actions.append(action)

  output_actions = tf.nest.pack_sequence_as(self._output_tensor_spec, actions)
  return output_actions, network_state

def call(self,
         observations,
         step_type=(),
         network_state=(),
         training=False,
         mask=None):
  outer_rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)

  # We use batch_squash here in case the observations have a time sequence
  # component.
  batch_squash = utils.BatchSquash(outer_rank)
  observations = tf.nest.map_structure(batch_squash.flatten, observations)

  # We ignore next_state from the encoder.
  state, network_state = self.encoder(
      observations, step_type=step_type, network_state=network_state)

  l_hidden = self.selector(state)
  selection_vector = tf.transpose(l_hidden, perm=[0, 2, 1])

  options = [option(state) for option in self.options]
  options = tf.stack(options)
  options = tf.transpose(options, perm=[1, 0, 2])

  # Select an option using the selection_vector from the master network.
  state = tf.matmul(selection_vector, options)
  state = batch_squash.unflatten(state)

  def call_projection_net(proj_net):
    distribution, _ = proj_net(state, outer_rank, training=training,
                               mask=mask)
    return distribution

  output_actions = tf.nest.map_structure(call_projection_net,
                                         self.projection_nets)
  return output_actions, network_state

def call(self, observation, step_type=None, network_state=(), training=False):
  del step_type  # unused.

  if self._batch_squash:
    outer_rank = nest_utils.get_outer_rank(observation,
                                           self.input_tensor_spec)
    batch_squash = utils.BatchSquash(outer_rank)
    observation = tf.nest.map_structure(batch_squash.flatten, observation)

  if self._flat_preprocessing_layers is None:
    processed = observation
  else:
    processed = []
    for obs, layer in zip(
        nest.flatten_up_to(self._preprocessing_nest, observation,
                           check_types=False),
        self._flat_preprocessing_layers):
      processed.append(layer(obs, training=training))
    if len(processed) == 1 and self._preprocessing_combiner is None:
      # If only one observation is passed and the preprocessing_combiner
      # is unspecified, use the preprocessed version of this observation.
      processed = processed[0]

  states = processed

  if self._preprocessing_combiner is not None:
    states = self._preprocessing_combiner(states)

  for layer in self._postprocessing_layers:
    states = layer(states, training=training)

  if self._batch_squash:
    states = tf.nest.map_structure(batch_squash.unflatten, states)

  return states, network_state

def call(self, observations, step_type=(), network_state=()):
  outer_rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)

  # batch_squash, in case observations have a time sequence component.
  batch_squash = utils.BatchSquash(outer_rank)
  observations = tf.nest.map_structure(batch_squash.flatten, observations)

  flat_x = self._flat_x(observations)

  cnn_1 = self._cnn_1(observations)
  cnn_2 = self._cnn_2(cnn_1)
  flat_cnn = self._flat_cnn(cnn_2)

  concat_cnn_x = tf.keras.layers.concatenate([flat_x, flat_cnn])

  dense_1 = self._dense_1(concat_cnn_x)
  dense_2 = self._dense_2(dense_1)

  concat_dense_x = tf.keras.layers.concatenate([flat_x, dense_2])

  # policy_dense_1 = self._policy_dense_1(concat_dense_x)
  actions = self._action_projection_layer(concat_dense_x)

  return tf.nest.pack_sequence_as(self._action_spec, [actions]), network_state

def call(self, inputs, step_type=(), network_state=(), training=False):
  observation, action = inputs
  observation_spec, _ = self.input_tensor_spec

  num_outer_dims = nest_utils.get_outer_rank(observation, observation_spec)
  has_time_dim = num_outer_dims == 2
  if has_time_dim:
    batch_squash = utils.BatchSquash(2)  # Squash B and T dims.
    # Flatten: [B, T, ...] -> [B x T, ...]
    observation = batch_squash.flatten(observation)
    action = batch_squash.flatten(action)

  q_value, network_state = super(CriticNet, self).call(
      (observation, action),
      step_type=step_type,
      network_state=network_state,
      training=training)

  if has_time_dim:
    q_value = batch_squash.unflatten(q_value)  # [B x T, ...] -> [B, T, ...]

  return q_value, network_state

def call(self, observations, step_type, network_state):
  del step_type  # unused.
  outer_rank = nest_utils.get_outer_rank(observations, self._observation_spec)
  observations = nest.flatten(observations)
  states = tf.cast(observations[0], tf.float32)

  # Reshape to only a single batch dimension for neural network functions.
  batch_squash = utils.BatchSquash(outer_rank)
  states = batch_squash.flatten(states)

  for layer in self._mlp_layers:
    states = layer(states)

  # TODO(oars): Can we avoid unflattening just to flatten again?
  states = batch_squash.unflatten(states)
  outputs = [
      projection(states, outer_rank)
      for projection in self._projection_networks
  ]
  return nest.pack_sequence_as(self._action_spec, outputs), network_state

def call(self,
         observations,
         step_type=(),
         network_state=(),
         training=False,
         mask=None):
  outer_rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)
  batch_squash = utils.BatchSquash(outer_rank)
  observations = tf.nest.map_structure(batch_squash.flatten, observations)

  state, network_state = self._encoder(
      observations, step_type=step_type, network_state=network_state)
  actions = self._action_projection_layer(state)
  actions = common_utils.scale_to_spec(actions, self._single_action_spec)
  actions = batch_squash.unflatten(actions)
  return tf.nest.pack_sequence_as(self._action_spec, [actions]), network_state

def call(self, inputs, outer_rank):
  # outer_rank is needed because the projection is not done on the raw
  # observations, so getting the outer rank is hard as there is no spec to
  # compare to.
  batch_squash = utils.BatchSquash(outer_rank)

  inputs = batch_squash.flatten(inputs)

  means = self._projection_layer(inputs)
  means = tf.reshape(means, [-1] + self._output_spec.shape.as_list())
  means = self._mean_transform(means, self._output_spec)
  means = tf.cast(means, self._output_spec.dtype)

  stds = self._bias(tf.zeros_like(means))
  stds = tf.reshape(stds, [-1] + self._output_spec.shape.as_list())
  stds = self._std_transform(stds)
  stds = tf.cast(stds, self._output_spec.dtype)

  means = batch_squash.unflatten(means)
  stds = batch_squash.unflatten(stds)

  return tfp.distributions.Normal(means, stds)

def call(self, observations, step_type=(), network_state=()):
  outer_rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)
  # We use batch_squash here in case the observations have a time sequence
  # component.
  batch_squash = utils.BatchSquash(outer_rank)
  observations = tf.nest.map_structure(batch_squash.flatten, observations)

  state, network_state = self._encoder(
      observations, step_type=step_type, network_state=network_state)
  actions = self._action_projection_layer(state)
  actions = common_utils.scale_to_spec(actions, self._action_spec)
  actions = batch_squash.unflatten(actions)
  return tf.nest.pack_sequence_as(self._action_spec, [actions]), network_state


#### ACTOR TEST ####
# action_spec = array_spec.BoundedArraySpec((6,), np.float32, minimum=0,
#                                           maximum=10)
# observation_spec = array_spec.BoundedArraySpec((64, 64, 3), np.float32,
#                                                minimum=0, maximum=255)
#
# random_env = random_py_environment.RandomPyEnvironment(
#     observation_spec, action_spec=action_spec)
#
# # Convert the environment to a TFEnv to generate tensors.
# tf_env = tf_py_environment.TFPyEnvironment(random_env)
#
# # preprocessing_layers = {
# #     'image': tf.keras.models.Sequential([tf.keras.layers.Conv2D(8, 4),
# #                                          tf.keras.layers.Flatten()]),
# #     'vector': tf.keras.layers.Dense(5)
# # }
# # preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)
# actor = ActorNetwork(tf_env.observation_spec(), tf_env.action_spec())
#
# time_step = tf_env.reset()
# # print(actor(time_step.observation, time_step.step_type))

def call(self, inputs: tf.Tensor,
         batch_dims: int) -> Tuple[tfp.distributions.OneHotCategorical, Tuple]:
  """Maps from a shared layer of hidden activations of the overall action
  net (inputs) to a distribution over actions for this head alone.

  :param inputs: The hidden activation from the final shared layer of the
      action network.
  :param batch_dims: The number of batch dimensions in the inputs.
  :return: A (OneHotCategorical) distribution over actions for this head and
      the network state (an empty tuple as this network is not stateful).
  """
  # outer_rank is needed because the projection is not done on the raw
  # observations, so getting the outer rank is hard as there is no spec to
  # compare to. BatchSquash is used to flatten and unflatten a tensor while
  # caching the original batch dimension(s).
  batch_squash = network_utils.BatchSquash(batch_dims)

  # We project the logits via a linear transformation to the right dimension
  # for the action head.
  inputs = batch_squash.flatten(inputs)
  logits = self._projection_layer(inputs)

  # We finally return the appropriate TensorFlow distribution and the
  # (empty) network state.
  return self.output_spec.build_distribution(logits=logits), ()

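# Hypothetical usage of the head above; `head`, its construction, and the
# action count (5) are illustrative assumptions, mirroring how other
# snippets in this file invoke projection networks as
# proj_net(state, outer_rank).
hidden = tf.zeros([32, 128])   # shared trunk activations [B, H]
dist, _ = head(hidden, 1)      # batch_dims=1: a plain batch, no time axis
greedy = dist.mode()           # one-hot action, shape [32, 5]
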
def call(self, inputs, step_type=(), network_state=()):
  outer_rank = nest_utils.get_outer_rank(inputs, self.input_tensor_spec)
  batch_squash = utils.BatchSquash(outer_rank)

  observations, actions = inputs
  observations, network_state = self._encoder(
      observations, step_type=step_type, network_state=network_state)
  observations = batch_squash.flatten(observations)

  actions = tf.cast(tf.nest.flatten(actions)[0], tf.float32)
  actions = batch_squash.flatten(actions)
  for layer in self._action_layers:
    actions = layer(actions)

  joint = tf.concat([observations, actions], -1)
  for layer in self._joint_layers:
    joint = layer(joint)

  q_value = tf.reshape(joint, [-1])
  q_value = batch_squash.unflatten(q_value)
  return q_value, network_state

def call(self, inputs, outer_rank):
  if inputs.dtype != self._sample_spec.dtype:
    raise ValueError(
        'Inputs to NormalProjectionNetwork must match the sample_spec.dtype.')
  # outer_rank is needed because the projection is not done on the raw
  # observations, so getting the outer rank is hard as there is no spec to
  # compare to.
  batch_squash = network_utils.BatchSquash(outer_rank)

  inputs = batch_squash.flatten(inputs)

  means = self._means_projection_layer(inputs)
  means = tf.reshape(means, [-1] + self._sample_spec.shape.as_list())

  if self._state_dependent_std:
    stds = self._stddev_projection_layer(inputs)
  else:
    stds = self._bias(tf.zeros_like(means))
    stds = tf.reshape(stds, [-1] + self._sample_spec.shape.as_list())

  inv_stds = self._std_transform(stds)
  if self._max_std is not None:
    inv_stds += 1 / (self._max_std - self._min_std)
  stds = 1. / inv_stds
  if self._min_std > 0:
    stds += self._min_std
  stds = tf.cast(stds, self._sample_spec.dtype)

  means = means * stds

  # If not scaling the distribution later, use a normalized mean.
  if not self._scale_distribution and self._mean_transform is not None:
    means = self._mean_transform(means, self._sample_spec)
  means = tf.cast(means, self._sample_spec.dtype)

  means = batch_squash.unflatten(means)
  stds = batch_squash.unflatten(stds)

  return self.output_spec.build_distribution(loc=means, scale=stds)

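# A quick numeric check of the std clamping above, assuming
# std_transform = tf.nn.softplus and illustrative bounds min_std = 0.1,
# max_std = 2.0 (not values taken from the original module). Since
# softplus(raw) ranges over (0, inf), inv_stds > 1 / (max_std - min_std),
# so 1. / inv_stds lies in (0, max_std - min_std) and the final stds lie
# strictly inside (min_std, max_std).
min_std, max_std = 0.1, 2.0
raw = tf.constant([-10.0, 0.0, 10.0])  # extreme pre-transform values
inv_stds = tf.nn.softplus(raw) + 1 / (max_std - min_std)
stds = 1. / inv_stds + min_std
print(stds.numpy())  # ~[2.0, 0.92, 0.2]: near, but never crossing, bounds
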
def call(self, inputs, outer_rank, training=False, mask=None):
  if inputs.dtype != self._sample_spec.dtype:
    raise ValueError(
        'Inputs to NormalProjectionNetwork must match the sample_spec.dtype.')
  if mask is not None:
    raise NotImplementedError(
        'NormalProjectionNetwork does not yet implement action masking; got '
        'mask={}'.format(mask))

  # outer_rank is needed because the projection is not done on the raw
  # observations, so getting the outer rank is hard as there is no spec to
  # compare to.
  batch_squash = network_utils.BatchSquash(outer_rank)
  inputs = batch_squash.flatten(inputs)

  means = self._means_projection_layer(inputs, training=training)
  means = tf.reshape(means, [-1] + self._sample_spec.shape.as_list())

  # If not scaling the distribution later, use a normalized mean.
  if not self._scale_distribution and self._mean_transform is not None:
    means = self._mean_transform(means, self._sample_spec)
  means = tf.cast(means, self._sample_spec.dtype)

  if self._state_dependent_std:
    stds = self._stddev_projection_layer(inputs, training=training)
  else:
    stds = self._bias(tf.zeros_like(means), training=training)
    stds = tf.reshape(stds, [-1] + self._sample_spec.shape.as_list())

  if self._std_transform is not None:
    stds = self._std_transform(stds)
  stds = tf.cast(stds, self._sample_spec.dtype)

  means = batch_squash.unflatten(means)
  stds = batch_squash.unflatten(stds)

  return self.output_spec.build_distribution(loc=means, scale=stds), ()

def call(self, inputs, step_type, network_state=None):
  outer_rank = nest_utils.get_outer_rank(inputs, self.input_tensor_spec)
  batch_squash = utils.BatchSquash(outer_rank)  # Squash B and T dims.

  observation, action = inputs
  observation, _ = self._obs_encoder(
      observation, step_type=step_type, network_state=network_state)

  output, network_state = self._lstm_encoder(
      inputs=(observation, action),
      step_type=step_type,
      network_state=network_state)

  output = batch_squash.flatten(output)
  for layer in self._output_layers:
    output = layer(output)

  q_value = tf.reshape(output, [-1])
  q_value = batch_squash.unflatten(q_value)
  return q_value, network_state

def call(self, inputs, unused_step_type=None, network_state=()):
  hidden_state = tf.cast(tf.nest.flatten(inputs)[0], tf.float32)

  # Calls coming from agent.train() have a time dimension. Direct loss calls
  # may not have a time dimension. In order to make BatchSquash work, we
  # need to specify the outer dimension properly.
  has_time_dim = nest_utils.get_outer_rank(inputs,
                                           self.input_tensor_spec) == 2
  outer_rank = 2 if has_time_dim else 1
  batch_squash = network_utils.BatchSquash(outer_rank)
  hidden_state = batch_squash.flatten(hidden_state)

  for layer in self.layers:
    hidden_state = layer(hidden_state)

  actions, stdevs = tf.split(hidden_state, 2, axis=1)
  actions = batch_squash.unflatten(actions)
  stdevs = batch_squash.unflatten(stdevs)
  actions = tf.nest.pack_sequence_as(self._action_spec, [actions])
  stdevs = tf.nest.pack_sequence_as(self._action_spec, [stdevs])

  return self.output_spec.build_distribution(
      loc=actions, scale=stdevs), network_state

def call(self, observations, step_type=(), network_state=(), training=False):
  num_outer_dims = nest_utils.get_outer_rank(observations,
                                             self.input_tensor_spec)
  has_time_dim = num_outer_dims == 2
  if has_time_dim:
    batch_squash = utils.BatchSquash(2)  # Squash B and T dims.
    # Flattening: [B, T, ...] -> [B x T, ...]
    observations = batch_squash.flatten(observations)

  z = self._z_encoder(observations, training=training)
  z = z.sample()

  if has_time_dim:
    z = batch_squash.unflatten(z)

  self._input_tensor_spec = self._z_spec
  output = super(RecurrentActorNet, self).call(
      z, step_type=step_type, network_state=network_state, training=training)
  self._input_tensor_spec = self._s_spec

  return output

def call(self, inputs, outer_rank):
  if inputs.dtype != self._sample_spec.dtype:
    raise ValueError(
        'Inputs to NormalProjectionNetwork must match the sample_spec.dtype.')
  # outer_rank is needed because the projection is not done on the raw
  # observations, so getting the outer rank is hard as there is no spec to
  # compare to.
  batch_squash = utils.BatchSquash(outer_rank)

  inputs = batch_squash.flatten(inputs)

  means = self._projection_layer(inputs)
  means = tf.reshape(means, [-1] + self._sample_spec.shape.as_list())
  means = self._mean_transform(means, self._sample_spec)
  means = tf.cast(means, self._sample_spec.dtype)

  stds = self._bias(tf.zeros_like(means))
  stds = tf.reshape(stds, [-1] + self._sample_spec.shape.as_list())
  stds = self._std_transform(stds)
  stds = tf.cast(stds, self._sample_spec.dtype)

  means = batch_squash.unflatten(means)
  stds = batch_squash.unflatten(stds)

  return self.output_spec.build_distribution(loc=means, scale=stds)

def normal(inputs,
           output_spec,
           outer_rank=1,
           projection_layer=default_fully_connected,
           mean_transform=tanh_squash_to_spec,
           std_initializer=tf.zeros_initializer(),
           std_transform=tf.exp,
           distribution_cls=tfp.distributions.Normal):
  """Projects a batch of inputs to a batch of means and standard deviations.

  Given an output spec for a single-tensor continuous action, produces a
  neural net layer converting inputs to a normal distribution matching the
  spec. The mean is derived from a fully connected linear layer as
  mean_transform(layer_output, output_spec). The std is fixed to a single
  trainable tensor (thus independent of the inputs). Specifically, std is
  parameterized as std_transform(variable).

  Args:
    inputs: An input Tensor of shape [batch_size, ?].
    output_spec: An output spec (either BoundedArraySpec or
      BoundedTensorSpec).
    outer_rank: The number of outer dimensions of inputs to consider batch
      dimensions and to treat as batch dimensions of the output distribution.
    projection_layer: Function taking in inputs, num_elements, scope and
      returning a projection of inputs to a Tensor of width num_elements.
    mean_transform: A function taking in layer output and the output_spec,
      returning the means. Defaults to tanh_squash_to_spec.
    std_initializer: Initializer for std_dev variables.
    std_transform: The function applied to the trainable std variable. For
      example, tf.exp (default) or tf.nn.softplus.
    distribution_cls: The distribution class to use for the output
      distribution. Defaults to tfp.distributions.Normal.

  Returns:
    A tfp.distributions.Normal object in which the standard deviation is not
    dependent on input.

  Raises:
    ValueError: If output_spec is invalid.
  """
  if not tensor_spec.is_bounded(output_spec):
    raise ValueError(
        'Input output_spec is of invalid type %s.' % type(output_spec))
  if not tensor_spec.is_continuous(output_spec):
    raise ValueError('Output is not continuous.')

  batch_squash = utils.BatchSquash(outer_rank)
  inputs = batch_squash.flatten(inputs)
  means = projection_layer(
      inputs, output_spec.shape.num_elements(), scope='means')
  stds = tf.contrib.layers.bias_add(
      tf.zeros_like(means),  # Independent of inputs.
      initializer=std_initializer,
      scope='stds',
      activation_fn=None)

  means = tf.reshape(means, [-1] + output_spec.shape.as_list())
  means = mean_transform(means, output_spec)
  means = tf.cast(means, output_spec.dtype)

  stds = tf.reshape(stds, [-1] + output_spec.shape.as_list())
  stds = std_transform(stds)
  stds = tf.cast(stds, output_spec.dtype)

  means, stds = batch_squash.unflatten(means), batch_squash.unflatten(stds)
  return distribution_cls(means, stds)

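# Hedged usage sketch for the `normal` helper above. Assumes a TF1-era
# environment (the helper relies on tf.contrib) and that the module-level
# defaults (default_fully_connected, tanh_squash_to_spec) are in scope.
from tf_agents.specs import tensor_spec

output_spec = tensor_spec.BoundedTensorSpec([3], tf.float32,
                                            minimum=-1.0, maximum=1.0)
hidden = tf.zeros([32, 64])         # [batch_size, hidden_width]
dist = normal(hidden, output_spec)  # a tfp.distributions.Normal
action = dist.sample()              # shape [32, 3]; std is shared batch-wide
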
def call(self, inputs, step_type, network_state=(), training=False):
  observation, action = inputs
  observation_spec, _ = self.input_tensor_spec

  num_outer_dims = nest_utils.get_outer_rank(observation, observation_spec)
  if num_outer_dims not in (1, 2):
    raise ValueError(
        'Input observation must have a batch or batch x time outer shape.')

  has_time_dim = num_outer_dims == 2
  if not has_time_dim:
    # Add a time dimension to the inputs.
    observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                        observation)
    action = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1), action)
    step_type = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                      step_type)

  observation = tf.cast(tf.nest.flatten(observation)[0], tf.float32)
  action = tf.cast(tf.nest.flatten(action)[0], tf.float32)

  batch_squash = utils.BatchSquash(2)  # Squash B and T dims.
  observation = batch_squash.flatten(observation)  # [B, T, ...] -> [BxT, ...]
  action = batch_squash.flatten(action)

  for layer in self._observation_layers:
    observation = layer(observation, training=training)

  for layer in self._action_layers:
    action = layer(action, training=training)

  joint = tf.concat([observation, action], -1)
  for layer in self._joint_layers:
    joint = layer(joint, training=training)

  joint = batch_squash.unflatten(joint)  # [B x T, ...] -> [B, T, ...]

  network_kwargs = {}
  if isinstance(self._lstm_network, dynamic_unroll_layer.DynamicUnroll):
    network_kwargs['reset_mask'] = tf.equal(
        step_type, time_step.StepType.FIRST, name='mask')

  # Unroll over the time sequence.
  output = self._lstm_network(
      inputs=joint,
      initial_state=network_state,
      training=training,
      **network_kwargs)

  if isinstance(self._lstm_network, dynamic_unroll_layer.DynamicUnroll):
    joint, network_state = output
  else:
    joint = output[0]
    network_state = tf.nest.pack_sequence_as(
        self._lstm_network.cell.state_size, tf.nest.flatten(output[1:]))

  output = batch_squash.flatten(joint)  # [B, T, ...] -> [B x T, ...]

  for layer in self._output_layers:
    output = layer(output, training=training)

  q_value = tf.reshape(output, [-1])
  q_value = batch_squash.unflatten(q_value)  # [B x T, ...] -> [B, T, ...]
  if not has_time_dim:
    q_value = tf.squeeze(q_value, axis=1)

  return q_value, network_state

def squash_dataset_element(sequence, info):
  return tf.nest.map_structure(
      utils.BatchSquash(2).flatten, (sequence, info))

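# A hedged sketch of where squash_dataset_element might be applied,
# assuming `replay_buffer` is a TF-Agents replay buffer whose as_dataset()
# yields (sequence, info) pairs shaped [B, T, ...].
dataset = replay_buffer.as_dataset(sample_batch_size=64, num_steps=2)
dataset = dataset.map(squash_dataset_element)  # [B, T, ...] -> [B*T, ...]
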
def call(self, observation, step_type, network_state=None, training=False):
  # Preprocess for multiple observations.
  if self._flat_preprocessing_layers is None:
    processed = observation
  else:
    processed = []
    for obs, layer in zip(
        nest.flatten_up_to(self._preprocessing_nest, observation,
                           check_types=False),
        self._flat_preprocessing_layers):
      processed.append(layer(obs, training=training))
    if len(processed) == 1 and self._preprocessing_combiner is None:
      # If only one observation is passed and the preprocessing_combiner
      # is unspecified, use the preprocessed version of this observation.
      processed = processed[0]
  observation = processed

  if self._preprocessing_combiner is not None:
    observation = self._preprocessing_combiner(observation)

  observation_spec = tensor_spec.TensorSpec((observation.shape[-1],),
                                            dtype=observation.dtype)
  num_outer_dims = nest_utils.get_outer_rank(observation, observation_spec)
  if num_outer_dims not in (1, 2):
    raise ValueError(
        'Input observation must have a batch or batch x time outer shape.')

  has_time_dim = num_outer_dims == 2
  if not has_time_dim:
    # Add a time dimension to the inputs.
    observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                        observation)
    step_type = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
                                      step_type)

  states = tf.cast(tf.nest.flatten(observation)[0], tf.float32)
  batch_squash = utils.BatchSquash(2)  # Squash B and T dims.
  states = batch_squash.flatten(states)  # [B, T, ...] -> [B x T, ...]

  for layer in self._input_layers:
    states = layer(states, training=training)

  states = batch_squash.unflatten(states)  # [B x T, ...] -> [B, T, ...]

  with tf.name_scope('reset_mask'):
    reset_mask = tf.equal(step_type, time_step.StepType.FIRST)

  # Unroll over the time sequence.
  states, network_state = self._dynamic_unroll(
      states, reset_mask, initial_state=network_state, training=training)

  states = batch_squash.flatten(states)  # [B, T, ...] -> [B x T, ...]

  for layer in self._output_layers:
    states = layer(states, training=training)

  actions = []
  for layer, spec in zip(self._action_layers, self._flat_action_spec):
    action = layer(states, training=training)
    action = common.scale_to_spec(action, spec)
    action = batch_squash.unflatten(action)  # [B x T, ...] -> [B, T, ...]
    if not has_time_dim:
      action = tf.squeeze(action, axis=1)
    actions.append(action)

  output_actions = tf.nest.pack_sequence_as(self._output_tensor_spec, actions)
  return output_actions, network_state

def _apply_actor_network(self, time_step, policy_state):
  has_batch_dim = time_step.step_type.shape.as_list()[0] > 1
  observation = time_step.observation
  if self._observation_normalizer:
    observation = self._observation_normalizer.normalize(observation)
  actions, policy_state = self._actor_network(
      observation, time_step.step_type, policy_state,
      training=self._training)
  if has_batch_dim:
    return actions, policy_state

  # Sample the "best" safe action out of 50 candidates.
  sampled_ac = actions.sample(50)
  obs = nest_utils.stack_nested_tensors(
      [time_step.observation for _ in range(50)])
  obs_outer_rank = nest_utils.get_outer_rank(obs,
                                             self.time_step_spec.observation)
  ac_outer_rank = nest_utils.get_outer_rank(sampled_ac, self.action_spec)
  obs_batch_squash = utils.BatchSquash(obs_outer_rank)
  ac_batch_squash = utils.BatchSquash(ac_outer_rank)
  obs = tf.nest.map_structure(obs_batch_squash.flatten, obs)
  sampled_ac = tf.nest.map_structure(ac_batch_squash.flatten, sampled_ac)
  q_val, _ = self._safety_critic_network((obs, sampled_ac),
                                         time_step.step_type)
  fail_prob = tf.nn.sigmoid(q_val)
  safe_ac_mask = fail_prob < self._safety_threshold
  safe_ac_idx = tf.where(safe_ac_mask)

  resample_count = 0
  start_time = time.time()
  while (self._training and resample_count < 4 and
         not safe_ac_idx.shape.as_list()[0]):
    if self._resample_counter is not None:
      self._resample_counter()
    resample_count += 1

    if isinstance(actions, dist_utils.SquashToSpecNormal):
      # Increase the variance by a constant factor of 1.5.
      scale = actions.input_distribution.scale * 1.5
      ac_mean = actions.mean()
    else:
      scale = actions.scale * 1.5
      ac_mean = actions.mean()
    actions = self._actor_network.output_spec.build_distribution(
        loc=ac_mean, scale=scale)
    sampled_ac = actions.sample(50)
    sampled_ac = tf.nest.map_structure(ac_batch_squash.flatten, sampled_ac)
    q_val, _ = self._safety_critic_network((obs, sampled_ac),
                                           time_step.step_type)

    fail_prob = tf.nn.sigmoid(q_val)
    safe_ac_idx = tf.where(fail_prob < self._safety_threshold)

  # logging.debug('resampled {} times, {} seconds'.format(
  #     resample_count, time.time() - start_time))

  sampled_ac = ac_batch_squash.unflatten(sampled_ac)
  if (None in safe_ac_idx.shape.as_list() or
      not np.prod(safe_ac_idx.shape.as_list())):
    # No safe action was found: return the safest available action.
    safe_idx = tf.argmin(fail_prob)
  else:
    sampled_ac = tf.gather(sampled_ac, safe_ac_idx)
    fail_prob_safe = tf.gather(fail_prob, safe_ac_idx)
    if self._training:
      # Pick the most unsafe action out of the "safe" options.
      safe_idx = tf.argmax(fail_prob_safe)[0]
    else:
      safe_idx = tf.argmin(fail_prob_safe)[0]
  ac = sampled_ac[safe_idx]
  assert ac.shape.as_list()[0] == 1, (
      'action shape is not correct: {}'.format(ac.shape.as_list()))
  return ac, policy_state

def _loss(self,
          experience,
          td_errors_loss_fn=tf.losses.huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None):
  """Computes critic loss for CategoricalDQN training.

  See Algorithm 1 and the discussion immediately preceding it on page 6 of
  "A Distributional Perspective on Reinforcement Learning", Bellemare et
  al., 2017 (https://arxiv.org/abs/1707.06887).

  Args:
    experience: A batch of experience data in the form of a `Trajectory`.
      The structure of `experience` must match that of
      `self.policy.step_spec`. All tensors in `experience` must be shaped
      `[batch, time, ...]` where `time` must be equal to
      `self.required_experience_time_steps` if that property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional weights used for importance sampling.

  Returns:
    critic_loss: A scalar critic loss.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires a time dimension to compute the loss properly.
  self._check_trajectory_dimensions(experience)

  if self._n_step_update == 1:
    time_steps, actions, next_time_steps = self._experience_to_transitions(
        experience)
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, actions, _ = self._experience_to_transitions(first_two_steps)
    _, _, next_time_steps = self._experience_to_transitions(last_two_steps)

  with tf.name_scope('critic_loss'):
    tf.nest.assert_same_structure(actions, self.action_spec)
    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    tf.nest.assert_same_structure(next_time_steps, self.time_step_spec)

    rank = nest_utils.get_outer_rank(time_steps.observation,
                                     self._time_step_spec.observation)

    # If inputs have a time dimension and the q_network is stateful,
    # combine the batch and time dimensions.
    batch_squash = (None
                    if rank <= 1 or self._q_network.state_spec in ((), None)
                    else utils.BatchSquash(rank))

    # q_logits contains the Q-value logits for all actions.
    q_logits, _ = self._q_network(time_steps.observation,
                                  time_steps.step_type)
    next_q_distribution = self._next_q_distribution(next_time_steps,
                                                    batch_squash)

    if batch_squash is not None:
      # Squash outer dimensions to a single dimension to facilitate
      # computing the loss below. Required for supporting temporal inputs,
      # for example.
      q_logits = batch_squash.flatten(q_logits)
      actions = batch_squash.flatten(actions)
      next_time_steps = tf.nest.map_structure(batch_squash.flatten,
                                              next_time_steps)

    actions = tf.nest.flatten(actions)[0]
    if actions.shape.ndims > 1:
      actions = tf.squeeze(actions, range(1, actions.shape.ndims))

    # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
    # support of Z_{\theta} (see Figure 1 in the paper).
    batch_size = tf.shape(q_logits)[0]
    tiled_support = tf.tile(self._support, [batch_size])
    tiled_support = tf.reshape(tiled_support, [batch_size, self._num_atoms])

    if self._n_step_update == 1:
      discount = next_time_steps.discount
      if discount.shape.ndims == 1:
        # We expect discount to have a shape of [batch_size], while
        # tiled_support will have a shape of [batch_size, num_atoms]. To
        # multiply these, we add a second dimension of 1 to the discount.
        discount = discount[:, None]
      next_value_term = tf.multiply(discount,
                                    tiled_support,
                                    name='next_value_term')

      reward = next_time_steps.reward
      if reward.shape.ndims == 1:
        # See the explanation above.
        reward = reward[:, None]
      reward_term = tf.multiply(reward_scale_factor,
                                reward,
                                name='reward_term')

      target_support = tf.add(reward_term,
                              gamma * next_value_term,
                              name='target_support')
    else:
      # When computing discounted return, we need to throw out the last time
      # index of both reward and discount, which are filled with dummy
      # values to match the dimensions of the observation.
      rewards = reward_scale_factor * experience.reward[:, :-1]
      discounts = gamma * experience.discount[:, :-1]

      # TODO(b/134618876): Properly handle Trajectories that include episode
      # boundaries with nonzero discount.

      # TODO(b/131557265): Replace value_ops.discounted_return with a method
      # that only computes the single value needed.
      discounted_rewards = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=tf.zeros([batch_size], dtype=discounts.dtype),
          time_major=False)

      # We only need the first value within the time dimension, which
      # corresponds to the full final return. The remaining values are only
      # partial returns.
      discounted_rewards = discounted_rewards[:, :1]

      final_value_discount = tf.reduce_prod(discounts, axis=1)
      final_value_discount = final_value_discount[:, None]

      # Save the values of discounted_rewards and final_value_discount in
      # order to check them in unit tests.
      self._discounted_rewards = discounted_rewards
      self._final_value_discount = final_value_discount

      target_support = tf.add(discounted_rewards,
                              final_value_discount * tiled_support,
                              name='target_support')

    target_distribution = tf.stop_gradient(
        project_distribution(target_support, next_q_distribution,
                             self._support))

    # Obtain the current Q-value logits for the selected actions.
    indices = tf.range(tf.shape(q_logits)[0])[:, None]
    indices = tf.cast(indices, actions.dtype)
    reshaped_actions = tf.concat([indices, actions[:, None]], 1)
    chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

    # Compute the cross-entropy loss between the logits. If inputs have
    # a time dimension, compute the sum over the time dimension before
    # computing the mean over the batch dimension.
    if batch_squash is not None:
      target_distribution = batch_squash.unflatten(target_distribution)
      chosen_action_logits = batch_squash.unflatten(chosen_action_logits)
      critic_loss = tf.reduce_mean(
          tf.reduce_sum(
              tf.nn.softmax_cross_entropy_with_logits_v2(
                  labels=target_distribution,
                  logits=chosen_action_logits),
              axis=1))
    else:
      critic_loss = tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits_v2(
              labels=target_distribution, logits=chosen_action_logits))

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar(
          'critic_loss', critic_loss, step=self.train_step_counter)

    if self._debug_summaries:
      distribution_errors = target_distribution - chosen_action_logits
      with tf.name_scope('distribution_errors'):
        common.generate_tensor_summaries(
            'distribution_errors',
            distribution_errors,
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean',
            tf.reduce_mean(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean_abs',
            tf.reduce_mean(tf.abs(distribution_errors)),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'max',
            tf.reduce_max(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'min',
            tf.reduce_min(distribution_errors),
            step=self.train_step_counter)
      with tf.name_scope('target_distribution'):
        common.generate_tensor_summaries(
            'target_distribution',
            target_distribution,
            step=self.train_step_counter)

    # TODO(b/127318640): Give appropriate values for td_loss and td_error
    # for prioritized replay.
    return tf_agent.LossInfo(critic_loss,
                             dqn_agent.DqnLossInfo(td_loss=(), td_error=()))

def actor_loss(self, time_steps, actions, next_time_steps, weights=None):
  """Computes the actor_loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    actor_loss: A scalar actor loss.
  """
  prev_time_steps, prev_actions, time_steps = time_steps, actions, next_time_steps  # pylint: disable=line-too-long

  with tf.name_scope('actor_loss'):
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)

    actions, log_pi = self._actions_and_log_probs(time_steps)
    target_input = (time_steps.observation, actions)
    target_q_values1, _ = self._critic_network_1(
        target_input, step_type=time_steps.step_type, training=False)
    target_q_values2, _ = self._critic_network_2(
        target_input, step_type=time_steps.step_type, training=False)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)
    actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values

    # Flatten the time dimension. We'll add it back when adding the loss.
    num_outer_dims = nest_utils.get_outer_rank(time_steps,
                                               self.time_step_spec)
    has_time_dim = (num_outer_dims == 2)
    if has_time_dim:
      batch_squash = utils.BatchSquash(2)  # Squash B and T dims.
      obs = batch_squash.flatten(time_steps.observation)
      prev_obs = batch_squash.flatten(prev_time_steps.observation)
      prev_actions = batch_squash.flatten(prev_actions)
    else:
      obs = time_steps.observation
      prev_obs = prev_time_steps.observation

    z = self._actor_network._z_encoder(obs, training=True)  # pylint: disable=protected-access
    prior = self._actor_network._predictor((prev_obs, prev_actions),  # pylint: disable=protected-access
                                           training=True)

    # kl is a vector of length batch_size, which has already been summed
    # over the latent dimension z.
    kl = tfp.distributions.kl_divergence(z, prior)
    if has_time_dim:
      kl = batch_squash.unflatten(kl)

    kl_coef = tf.stop_gradient(
        tf.exp(self._actor_network._log_kl_coefficient))  # pylint: disable=protected-access
    # The actor loss trains both the predictor and the encoder.
    actor_loss += kl_coef * kl

    if actor_loss.shape.rank > 1:
      # Sum over the time dimension.
      actor_loss = tf.reduce_sum(
          actor_loss, axis=range(1, actor_loss.shape.rank))
    reg_loss = self._actor_network.losses if self._actor_network else None
    agg_loss = common.aggregate_losses(
        per_example_loss=actor_loss,
        sample_weight=weights,
        regularization_loss=reg_loss)
    actor_loss = agg_loss.total_loss
    self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                     target_q_values, time_steps)
    tf.compat.v2.summary.scalar(
        name='encoder_kl',
        data=tf.reduce_mean(kl),
        step=self.train_step_counter)

    return actor_loss