def _critic_train_step(self, exp: Experience, state: SacCriticState, action,
                       log_pi):
  if self._is_continuous:
    critic_input = (exp.observation, exp.action)
    target_critic_input = (exp.observation, action)
  else:
    critic_input = exp.observation
    target_critic_input = exp.observation

  critic1, critic1_state = self._critic_network1(
      critic_input, step_type=exp.step_type, network_state=state.critic1)
  critic2, critic2_state = self._critic_network2(
      critic_input, step_type=exp.step_type, network_state=state.critic2)
  target_critic1, target_critic1_state = self._target_critic_network1(
      target_critic_input,
      step_type=exp.step_type,
      network_state=state.target_critic1)
  target_critic2, target_critic2_state = self._target_critic_network2(
      target_critic_input,
      step_type=exp.step_type,
      network_state=state.target_critic2)

  if not self._is_continuous:
    exp_action = tf.cast(exp.action, tf.int32)
    critic1 = tfa_common.index_with_actions(critic1, exp_action)
    critic2 = tfa_common.index_with_actions(critic2, exp_action)
    sampled_action = tf.cast(action, tf.int32)
    target_critic1 = tfa_common.index_with_actions(target_critic1,
                                                   sampled_action)
    target_critic2 = tfa_common.index_with_actions(target_critic2,
                                                   sampled_action)

  target_critic = (tf.minimum(target_critic1, target_critic2) -
                   tf.stop_gradient(tf.exp(self._log_alpha) * log_pi))

  state = SacCriticState(
      critic1=critic1_state,
      critic2=critic2_state,
      target_critic1=target_critic1_state,
      target_critic2=target_critic2_state)
  info = SacCriticInfo(
      critic1=critic1, critic2=critic2, target_critic=target_critic)

  return state, info
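# Hedged standalone sketch (not part of the agent above): shows how the
# SacCriticInfo fields are typically combined into a critic TD loss. The
# reward/discount tensors and the squared-error loss are assumptions used
# only for illustration.
import tensorflow as tf

critic1 = tf.constant([1.2, 0.7])        # Q1(s, a)
critic2 = tf.constant([1.0, 0.9])        # Q2(s, a)
target_critic = tf.constant([0.8, 0.5])  # min(Q1', Q2') - alpha * log_pi
reward = tf.constant([1.0, 0.0])
discount = tf.constant([0.99, 0.99])

td_target = tf.stop_gradient(reward + discount * target_critic)
critic_loss = (tf.reduce_mean(tf.square(critic1 - td_target)) +
               tf.reduce_mean(tf.square(critic2 - td_target)))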
def _compute_next_q_values(self, target_policies, time_steps):
  """Compute the q value of the next state for TD error computation.

  Args:
    target_policies: list of target HeteroQPolicy objects
    time_steps: A batch of current timesteps

  Returns:
    A tensor of Q values for the given next state.
  """
  q_values_seq, actions = self._sequential_network_activation(
      target_policies, time_steps)
  action_key = 'raw'
  raw_actions = tf.unstack(actions[action_key], axis=1)
  multi_dim_actions = False
  values = [
      common.index_with_actions(
          self._append2logits(q_values),
          tf.cast(act, dtype=tf.int32),
          multi_dim_actions=multi_dim_actions)
      for q_values, act in zip(q_values_seq, raw_actions)
  ]
  # Due to dist.mode() in GreedyPolicy, action 0 is selected for masked
  # actions, so the resulting -inf values need to be removed.
  values = [tf.where(tf.equal(v, NEG_INF), 0.0, v) for v in values]
  # Specifically set for SC2 minimaps.
  values = [
      tf.where(tf.equal(time_steps.step_type, StepType.LAST), 0.0, v)
      for v in values
  ]
  values = [
      tf.where(tf.equal(time_steps.reward, 1.0), 0.0, v) for v in values
  ]
  values = tf.add_n(values)
  return values
def testTwoOuterDimsUnknownShape(self):
  q_values = tf.convert_to_tensor(
      value=np.array([[[50, 51], [52, 53]]], dtype=np.float32))
  actions = tf.convert_to_tensor(value=np.array([[1, 0]], dtype=np.int32))
  values = common.index_with_actions(q_values, actions)

  self.assertAllClose([[51, 52]], self.evaluate(values))
def compute_loss(self,
                 observations: types.NestedTensor,
                 actions: types.NestedTensor,
                 rewards: types.Tensor,
                 weights: Optional[types.TensorOrArray] = None,
                 training: bool = False) -> types.Tensor:
  """Computes loss for training the constraint network.

  Args:
    observations: A batch of observations.
    actions: A batch of actions.
    rewards: A batch of rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output batch loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.
    training: Whether the loss is being used for training.

  Returns:
    loss: A `Tensor` containing the loss for the training step.
  """
  with tf.name_scope('constraint_loss'):
    sample_weights = weights if weights else 1
    predicted_values, _ = self._constraint_network(
        observations, training=training)
    action_predicted_values = common.index_with_actions(
        predicted_values, tf.cast(actions, dtype=tf.int32))
    loss = self._error_loss_fn(
        rewards,
        action_predicted_values,
        sample_weights,
        reduction=tf.compat.v1.losses.Reduction.MEAN)
    return loss
def _compute_next_q_values(self, next_time_steps):
  """Compute the q value of the next state for TD error computation.

  Args:
    next_time_steps: A batch of next timesteps

  Returns:
    A tensor of Q values for the given next state.
  """
  # TODO(b/117175589): Add binary tests for DDQN.
  next_target_q_values, _ = self._target_q_network(
      next_time_steps.observation, next_time_steps.step_type)
  batch_size = (
      next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0])
  dummy_state = self._greedy_policy.get_initial_state(batch_size)
  # Find the greedy actions using our greedy policy. This ensures that masked
  # actions (and other logic) are respected.
  best_next_actions = self._greedy_policy.action(next_time_steps,
                                                 dummy_state).action
  # Handle action_spec.shape=() and shape=(1,) by using the multi_dim_actions
  # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
  multi_dim_actions = tf.nest.flatten(self._action_spec)[0].shape.ndims > 0
  return common.index_with_actions(
      next_target_q_values,
      best_next_actions,
      multi_dim_actions=multi_dim_actions)
def _apply_action(self, action):
  self._reward = self._environment_dynamics.reward(self._observation,
                                                   self._env_time)
  tf.compat.v1.assign_add(self._env_time,
                          self._environment_dynamics.batch_size)
  return common.index_with_actions(self._reward,
                                   tf.cast(action, dtype=tf.int32))
def loss(self, observations, actions, rewards, weights=None):
  """Computes loss for reward prediction training.

  Args:
    observations: A batch of observations.
    actions: A batch of actions.
    rewards: A batch of rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output batch loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.

  Returns:
    loss: A `LossInfo` containing the loss for the training step.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  with tf.name_scope('loss'):
    predicted_values, _ = self._reward_network(observations)
    action_predicted_values = common.index_with_actions(
        predicted_values, tf.cast(actions, dtype=tf.int32))
    loss = self._error_loss_fn(rewards, action_predicted_values,
                               weights if weights else 1)
    return tf_agent.LossInfo(loss, extra=())
def _compute_next_q_values(self, next_time_steps, index, time):
  """spot and flex dispatches of last day (48) + MCP current hour +
  current state (spot/flex)."""
  obs_indices = tf.convert_to_tensor(list(range(0, 48)) + [48 + time] + [71])
  observation = tf.gather(
      next_time_steps.observation[index], indices=obs_indices, axis=-1)
  network_observation = observation

  next_target_q_values, _ = self.agent._target_q_network(
      network_observation, next_time_steps.step_type)
  batch_size = (
      next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0])
  dummy_state = self.agent._target_greedy_policy.get_initial_state(batch_size)
  # Find the greedy actions using our target greedy policy. This ensures that
  # action constraints are respected and helps centralize the greedy logic.
  next_individual_qmix_time_step = ts.get_individual_qmix_time_step(
      next_time_steps, index, time)
  greedy_actions = self.agent._target_greedy_policy.action(
      next_individual_qmix_time_step, dummy_state).action
  # Handle action_spec.shape=() and shape=(1,) by using the multi_dim_actions
  # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
  multi_dim_actions = tf.nest.flatten(
      self.agent._action_spec)[0].shape.rank > 0
  return common.index_with_actions(
      next_target_q_values,
      greedy_actions,
      multi_dim_actions=multi_dim_actions)
def _compute_next_q_values(self, next_time_steps):
  """Compute the q value of the next state for TD error computation.

  Args:
    next_time_steps: A batch of next timesteps

  Returns:
    A tensor of Q values for the given next state.
  """
  network_observation = next_time_steps.observation

  if self._observation_and_action_constraint_splitter is not None:
    network_observation, _ = self._observation_and_action_constraint_splitter(
        network_observation)

  next_target_q_values, _ = self._target_q_network(
      network_observation, next_time_steps.step_type)
  batch_size = (
      next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0])
  dummy_state = self._target_greedy_policy.get_initial_state(batch_size)
  # Find the greedy actions using our target greedy policy. This ensures that
  # action constraints are respected and helps centralize the greedy logic.
  greedy_actions = self._target_greedy_policy.action(next_time_steps,
                                                     dummy_state).action
  # Handle action_spec.shape=() and shape=(1,) by using the multi_dim_actions
  # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
  multi_dim_actions = tf.nest.flatten(self._action_spec)[0].shape.rank > 0
  return common.index_with_actions(
      next_target_q_values,
      greedy_actions,
      multi_dim_actions=multi_dim_actions)
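# Hedged sketch of an observation_and_action_constraint_splitter compatible
# with the call above: it maps the raw observation to (network_observation,
# action_mask). The 'state' and 'mask' keys below are assumptions for
# illustration, not part of the agent's actual observation spec.
def example_observation_and_action_constraint_splitter(observation):
  return observation['state'], observation['mask']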
def _compute_next_q_values(self, next_time_steps, info):
  """Compute the q value of the next state for TD error computation.

  Args:
    next_time_steps: A batch of next timesteps
    info: PolicyStep.info that may be used by other agents inherited from
      dqn_agent.

  Returns:
    A tensor of Q values for the given next state.
  """
  del info
  # TODO(b/117175589): Add binary tests for DDQN.
  network_observation = next_time_steps.observation

  if self._observation_and_action_constraint_splitter is not None:
    network_observation, _ = self._observation_and_action_constraint_splitter(
        network_observation)

  next_target_q_values, _ = self._target_q_network(
      network_observation, step_type=next_time_steps.step_type)
  batch_size = (
      next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0])
  dummy_state = self._policy.get_initial_state(batch_size)
  # Find the greedy actions using our greedy policy. This ensures that action
  # constraints are respected and helps centralize the greedy logic.
  best_next_actions = self._policy.action(next_time_steps, dummy_state).action
  # Handle action_spec.shape=() and shape=(1,) by using the multi_dim_actions
  # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
  multi_dim_actions = tf.nest.flatten(self._action_spec)[0].shape.rank > 0
  return common.index_with_actions(
      next_target_q_values,
      best_next_actions,
      multi_dim_actions=multi_dim_actions)
def compute_munchausen_td_targets(next_q_values, q_target_values, actions,
                                  rewards, multi_dim_actions, discounts,
                                  alpha, entropy_tau):
  next_max_v_values = tf.expand_dims(tf.reduce_max(next_q_values, 1), -1)
  tau_logsum_next = entropy_tau * tf.reduce_logsumexp(
      (next_q_values - next_max_v_values) / entropy_tau,
      axis=1)  # batch x actions
  tau_logsum_next = tf.expand_dims(tau_logsum_next, -1)
  tau_logpi_next = next_q_values - next_max_v_values - tau_logsum_next
  pi_target = tf.nn.softmax(next_q_values / entropy_tau, 1)

  # valid_mask shape: (batch_size, )
  q_target = discounts * tf.reduce_sum(
      (pi_target * (next_q_values - tau_logpi_next)), 1)  # * valid_mask

  v_target_max = tf.expand_dims(tf.reduce_max(q_target_values, 1), -1)
  tau_logsum_target = entropy_tau * tf.reduce_logsumexp(
      (q_target_values - v_target_max) / entropy_tau, 1)
  tau_logsum_target = tf.expand_dims(tau_logsum_target, -1)
  tau_logpi_target = q_target_values - v_target_max - tau_logsum_target

  # The Munchausen addon uses the current state and actions.
  munchausen_addon = common.index_with_actions(
      tau_logpi_target, tf.cast(actions, dtype=tf.int32), multi_dim_actions)

  # rewards = reward_scale_factor * next_time_steps.reward
  munchausen_reward = rewards + alpha * tf.clip_by_value(
      munchausen_addon, clip_value_max=0, clip_value_min=-2)
  td_targets = munchausen_reward + q_target
  return tf.stop_gradient(td_targets)
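# Hedged toy call of compute_munchausen_td_targets above; all tensors and the
# hyperparameters (alpha, entropy_tau) are illustrative, not values from the
# original agent.
import tensorflow as tf

example_munchausen_targets = compute_munchausen_td_targets(
    next_q_values=tf.constant([[1.0, 2.0], [0.5, 0.0]]),
    q_target_values=tf.constant([[1.5, 0.5], [0.2, 0.8]]),
    actions=tf.constant([0, 1], dtype=tf.int32),
    rewards=tf.constant([1.0, 0.0]),
    multi_dim_actions=False,
    discounts=tf.constant([0.99, 0.99]),
    alpha=0.9,
    entropy_tau=0.03)  # shape [batch_size]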
def _compute_q_values(self, policies, time_steps, actions):
  """Compute the q value of the current state/action for TD error computation.

  Args:
    policies: list of HeteroQPolicy objects
    time_steps: A batch of current timesteps
    actions: A batch of actions

  Returns:
    A tensor of Q values for the given state and actions.
  """
  q_values_seq, _ = self._sequential_network_activation(
      policies, time_steps, actions)
  action_key = 'raw'
  raw_actions = tf.unstack(actions[action_key], axis=1)
  # Handle action_spec.shape=() and shape=(1,) by using the multi_dim_actions
  # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
  multi_dim_actions = False
  values = [
      common.index_with_actions(
          self._append2logits(q_values),
          tf.cast(act, dtype=tf.int32),
          multi_dim_actions=multi_dim_actions)
      for q_values, act in zip(q_values_seq, raw_actions)
  ]
  # Due to dist.mode() in GreedyPolicy, action 0 is selected for masked
  # actions, so the resulting -inf values need to be removed.
  values = [tf.where(tf.equal(v, NEG_INF), 0.0, v) for v in values]
  values = tf.add_n(values)
  return values
def _compute_q_values(self, time_steps, actions):
  q_values, _ = self._q_network(time_steps.observation, time_steps.step_type)
  # Handle action_spec.shape=() and shape=(1,) by using the multi_dim_actions
  # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
  multi_dim_actions = self._action_spec.shape.ndims > 0
  return common.index_with_actions(
      q_values,
      tf.cast(actions, dtype=tf.int32),
      multi_dim_actions=multi_dim_actions)
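# Hedged standalone illustration of the gather performed above: with q_values
# of shape [batch, num_actions] and scalar actions, index_with_actions returns
# the Q value of the taken action for each batch entry. The numbers are made
# up for illustration.
import tensorflow as tf
from tf_agents.utils import common

example_q_values = tf.constant([[0.1, 0.9], [0.4, 0.6]])
example_actions = tf.constant([1, 0], dtype=tf.int32)
taken_q = common.index_with_actions(example_q_values, example_actions)
# taken_q == [0.9, 0.4]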
def testTwoOuterDimsUnknownShape(self):
  q_values = tf.placeholder(tf.float32, shape=[None, None, None])
  actions = tf.placeholder(tf.int32, shape=[None, None])
  values = common.index_with_actions(q_values, actions)

  with self.test_session() as sess:
    self.assertAllClose(
        [[51, 52]],
        sess.run(
            values,
            feed_dict={
                q_values: [[[50, 51], [52, 53]]],
                actions: [[1, 0]]
            }))
def checkCorrect(self,
                 q_values,
                 actions,
                 expected_values,
                 multi_dim_actions=False):
  q_values = tf.constant(q_values, dtype=tf.float32)
  actions = tf.constant(actions, dtype=tf.int32)
  selected_q_values = common.index_with_actions(q_values, actions,
                                                multi_dim_actions)
  selected_q_values_ = self.evaluate(selected_q_values)
  self.assertAllClose(selected_q_values_, expected_values)
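# Hedged companion test sketch using the helper above; the two cases mirror
# the single-dimensional and multi-dimensional action contracts of
# index_with_actions. The test name and values are assumptions, not from the
# original suite.
def testScalarAndMultiDimActions(self):
  # action_spec.shape=(): actions are scalars per batch entry.
  self.checkCorrect(
      q_values=[[1.0, 2.0], [3.0, 4.0]],
      actions=[1, 0],
      expected_values=[2.0, 3.0])
  # action_spec.shape=(1,): actions carry an explicit trailing dimension.
  self.checkCorrect(
      q_values=[[1.0, 2.0], [3.0, 4.0]],
      actions=[[1], [0]],
      expected_values=[2.0, 3.0],
      multi_dim_actions=True)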
def _compute_q_values(self, time_steps, actions):
  q_values, _ = self._q_network(time_steps.observation, time_steps.step_type)
  # Handle action_spec.shape=() and shape=(1,) by using the
  # multi_dim_actions param.
  multi_dim_actions = tf.nest.flatten(self._action_spec)[0].shape.ndims > 0
  q_values = common.index_with_actions(
      q_values,
      tf.cast(actions, dtype=tf.int32),
      multi_dim_actions=multi_dim_actions)
  return q_values
def _compute_q_values(self, time_steps, actions):
  network_observation = time_steps.observation

  if self._observation_anc_action_constraint_splitter:
    network_observation, _ = self._observation_anc_action_constraint_splitter(
        network_observation)

  q_values, _ = self._q_network(network_observation, time_steps.step_type)
  multi_dim_actions = self._action_spec.shape.rank > 0
  return common.index_with_actions(
      q_values,
      tf.cast(actions, dtype=tf.int32),
      multi_dim_actions=multi_dim_actions)
def _loss_using_reward_layer(self,
                             observations: types.NestedTensor,
                             actions: types.Tensor,
                             rewards: types.Tensor,
                             weights: Optional[types.Float] = None,
                             training: bool = False) -> tf_agent.LossInfo:
  """Computes loss for reward prediction training.

  Args:
    observations: A batch of observations.
    actions: A batch of actions.
    rewards: A batch of rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output batch loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.
    training: Whether the loss is being used for training.

  Returns:
    loss: A `LossInfo` containing the loss for the training step.
  """
  with tf.name_scope('loss'):
    encoded_observation, _ = self._encoding_network(
        observations, training=training)
    encoded_observation = tf.reshape(
        encoded_observation, shape=[-1, self._encoding_dim])
    predicted_rewards = self._reward_layer(
        encoded_observation, training=training)
    chosen_actions_predicted_rewards = common.index_with_actions(
        predicted_rewards, tf.cast(actions, dtype=tf.int32))
    loss = self._error_loss_fn(rewards, chosen_actions_predicted_rewards,
                               weights if weights else 1)
    if self._summarize_grads_and_vars:
      with tf.name_scope('Per_arm_loss/'):
        for k in range(self._num_models):
          loss_mask_for_arm = tf.cast(tf.equal(actions, k), tf.float32)
          loss_for_arm = self._error_loss_fn(
              rewards,
              chosen_actions_predicted_rewards,
              weights=loss_mask_for_arm)
          tf.compat.v2.summary.scalar(
              name='loss_arm_' + str(k),
              data=loss_for_arm,
              step=self.train_step_counter)
    return tf_agent.LossInfo(loss, extra=())
def _compute_q_values(self, time_steps, actions):
  network_observation = time_steps.observation

  if self._observation_and_action_constraint_splitter is not None:
    network_observation, _ = self._observation_and_action_constraint_splitter(
        network_observation)

  q_values, _ = self._q_network(network_observation, time_steps.step_type)
  # Handle action_spec.shape=() and shape=(1,) by using the multi_dim_actions
  # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
  multi_dim_actions = self._action_spec.shape.rank > 0
  return common.index_with_actions(
      q_values,
      tf.cast(actions, dtype=tf.int32),
      multi_dim_actions=multi_dim_actions)
def call(self, trajectory):
  """Update the constraint violations metric.

  Args:
    trajectory: A tf_agents.trajectory.Trajectory

  Returns:
    The arguments, for easy chaining.
  """
  feasibility_prob_all_actions = self._constraint(trajectory.observation)
  feasibility_prob_selected_actions = common.index_with_actions(
      feasibility_prob_all_actions,
      tf.cast(trajectory.action, dtype=tf.int32))
  self.constraint_violations.assign(
      tf.reduce_mean(1.0 - feasibility_prob_selected_actions))
  return trajectory
def compute_dpp_td_targets(next_p_values, p_target_values, actions, rewards,
                           multi_dim_actions, discounts, alpha, entropy_tau):
  boltzmann_p = tf.reduce_sum(
      tf.nn.softmax(p_target_values / entropy_tau, axis=1) * p_target_values,
      1)
  p_target_values = common.index_with_actions(
      p_target_values, tf.cast(actions, dtype=tf.int32), multi_dim_actions)
  action_gap = alpha * (p_target_values - boltzmann_p)

  next_boltzmann_p = tf.reduce_sum(
      tf.nn.softmax(next_p_values / entropy_tau, axis=1) * next_p_values, 1)
  td_targets = rewards + discounts * next_boltzmann_p + action_gap
  return tf.stop_gradient(td_targets)
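# Hedged toy call of compute_dpp_td_targets above; tensors and hyperparameters
# are illustrative only, not values from the original code.
import tensorflow as tf

example_dpp_targets = compute_dpp_td_targets(
    next_p_values=tf.constant([[0.2, 0.8], [0.5, 0.5]]),
    p_target_values=tf.constant([[0.3, 0.7], [0.6, 0.4]]),
    actions=tf.constant([1, 0], dtype=tf.int32),
    rewards=tf.constant([1.0, 0.0]),
    multi_dim_actions=False,
    discounts=tf.constant([0.99, 0.99]),
    alpha=0.9,
    entropy_tau=0.03)  # shape [batch_size]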
def _compute_next_q_values(self, next_time_steps):
  network_observation = next_time_steps.observation

  if self._observation_anc_action_constraint_splitter:
    network_observation, _ = self._observation_anc_action_constraint_splitter(
        network_observation)

  next_target_q_values, _ = self._target_q_network(
      network_observation, next_time_steps.step_type)
  batch_size = (
      next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0])
  dummy_state = self._target_greedy_policy.get_initial_state(batch_size)
  greedy_actions = self._target_greedy_policy.action(next_time_steps,
                                                     dummy_state).action
  multi_dim_actions = tf.nest.flatten(self._action_spec)[0].shape.rank > 0
  return common.index_with_actions(
      next_target_q_values,
      greedy_actions,
      multi_dim_actions=multi_dim_actions)
def _compute_next_q_values(self, next_time_steps):
  """Compute the q value of the next state for TD error computation.

  Args:
    next_time_steps: A batch of next timesteps

  Returns:
    A tensor of Q values for the given next state.
  """
  # TODO(b/117175589): Add binary tests for DDQN.
  next_q_values, _ = self._q_network(next_time_steps.observation,
                                     next_time_steps.step_type)
  best_next_actions = tf.cast(
      tf.argmax(input=next_q_values, axis=-1), dtype=tf.int32)
  next_target_q_values, _ = self._target_q_network(
      next_time_steps.observation, next_time_steps.step_type)
  multi_dim_actions = best_next_actions.shape.ndims > 1
  return common.index_with_actions(
      next_target_q_values,
      best_next_actions,
      multi_dim_actions=multi_dim_actions)
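# Hedged numeric illustration of the Double-DQN pattern above: the next action
# is picked by the online network's argmax but evaluated with the target
# network. The values are made up for illustration.
import tensorflow as tf
from tf_agents.utils import common

online_next_q = tf.constant([[0.2, 0.9], [0.7, 0.1]])
target_next_q = tf.constant([[0.5, 0.4], [0.3, 0.6]])
best_actions = tf.cast(tf.argmax(online_next_q, axis=-1), tf.int32)  # [1, 0]
double_dqn_values = common.index_with_actions(target_next_q, best_actions)
# double_dqn_values == [0.4, 0.3]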
def _compute_q_values(self, q_network, time_steps, discrete_actions):
  continuous_action_values, _ = self._actor_network(time_steps.observation,
                                                    time_steps.step_type)
  # noisy_target_action_values = tf.nest.map_structure(
  #     self._add_noise_to_action, target_action_values)
  time_step_obs = tf.nest.flatten(
      time_steps.observation) + [continuous_action_values]
  if isinstance(q_network, QNetwork):
    time_step_obs = tf.concat(time_step_obs, axis=-1)
  q_values, _ = q_network(time_step_obs, time_steps.step_type)
  # Handle action_spec.shape=() and shape=(1,) by using the multi_dim_actions
  # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
  multi_dim_actions = tf.nest.flatten(
      self._action_spec.q_network)[0].shape.ndims > 0
  return common.index_with_actions(
      q_values,
      tf.cast(discrete_actions, dtype=tf.int32),
      multi_dim_actions=multi_dim_actions), continuous_action_values
def __call__(self, observation, actions=None):
  """Returns the probability of input actions being feasible."""
  predicted_quantiles, _ = self._constraint_network(
      observation, training=False)
  batch_dims = nest_utils.get_outer_shape(observation,
                                          self._time_step_spec.observation)
  if self._baseline_action_fn is not None:
    baseline_action = self._baseline_action_fn(observation)
    baseline_action.shape.assert_is_compatible_with(batch_dims)
  else:
    baseline_action = tf.zeros(batch_dims, dtype=tf.int32)

  predicted_quantiles_for_baseline_actions = common.index_with_actions(
      predicted_quantiles, tf.cast(baseline_action, dtype=tf.int32))
  predicted_quantiles_for_baseline_actions = self._reshape_and_broadcast(
      predicted_quantiles_for_baseline_actions, tf.shape(predicted_quantiles))
  is_satisfied = self._comparator_fn(predicted_quantiles,
                                     predicted_quantiles_for_baseline_actions)
  return tf.cast(is_satisfied, tf.float32)
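# Hedged sketch of a comparator_fn compatible with the call above: an action
# is treated as feasible when its predicted quantile does not fall below the
# baseline action's quantile by more than a margin. The margin parameter is an
# assumption for illustration, not part of the original constraint.
def example_margin_comparator_fn(predicted_quantiles,
                                 baseline_quantiles,
                                 margin=0.1):
  return predicted_quantiles >= baseline_quantiles - margin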
def _compute_q_values(self, time_steps, actions, index, time, training=False):
  """spot and flex dispatches of last day (48) + MCP current hour +
  current state (spot/flex)."""
  obs_indices = tf.convert_to_tensor(list(range(0, 48)) + [48 + time] + [71])
  observation = tf.gather(
      time_steps.observation[index], indices=obs_indices, axis=-1)
  network_observation = observation

  q_values, _ = self.agent._q_network(
      network_observation, time_steps.step_type, training=training)
  # Handle action_spec.shape=() and shape=(1,) by using the multi_dim_actions
  # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
  multi_dim_actions = self.agent._action_spec.shape.rank > 0
  actions = tf.reshape(actions, [-1, 1])
  return common.index_with_actions(
      q_values,
      tf.cast(actions, dtype=tf.int32),
      multi_dim_actions=multi_dim_actions)
def compute_loss(self,
                 observations: types.NestedTensor,
                 actions: types.NestedTensor,
                 rewards: types.Tensor,
                 weights: Optional[types.TensorOrArray] = None,
                 training: bool = False) -> types.Tensor:
  """Computes loss for training the constraint network.

  Args:
    observations: A batch of observations.
    actions: A batch of actions.
    rewards: A batch of rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output batch loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.
    training: Whether the loss is being used for training.

  Returns:
    loss: A `Tensor` containing the loss for the training step.
  """
  with tf.name_scope('constraint_loss'):
    sample_weights = weights if weights else 1
    predicted_values, _ = self._constraint_network(
        observations, training=training)
    action_predicted_values = common.index_with_actions(
        predicted_values, tf.cast(actions, dtype=tf.int32))
    # Reduction is done outside of the loss function because non-scalar
    # weights with unknown shapes may trigger shape validation that fails
    # XLA compilation.
    return tf.reduce_mean(
        tf.multiply(
            self._error_loss_fn(
                rewards,
                action_predicted_values,
                reduction=tf.compat.v1.losses.Reduction.NONE),
            sample_weights))
def _compute_next_q_values(self, target_q_network, next_time_steps):
  """Compute the q value of the next state for TD error computation.

  Args:
    target_q_network: The target Q network used to evaluate the next state.
    next_time_steps: A batch of next timesteps

  Returns:
    A tensor of Q values for the given next state.
  """
  next_target_continuous_action_values, _ = self._target_actor_network(
      next_time_steps.observation, next_time_steps.step_type)
  noisy_target_action_values = tf.nest.map_structure(
      self._add_noise_to_action, next_target_continuous_action_values)
  time_step_obs = tf.nest.flatten(
      next_time_steps.observation) + [noisy_target_action_values]
  if isinstance(target_q_network, QNetwork):
    time_step_obs = tf.concat(time_step_obs, axis=-1)
  next_target_q_values, _ = target_q_network(time_step_obs,
                                             next_time_steps.step_type)
  batch_size = (
      next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0])
  dummy_state = self._target_greedy_policy.get_initial_state(batch_size)
  # Find the greedy actions using our target greedy policy. This ensures that
  # masked actions are respected and helps centralize the greedy logic.
  greedy_discrete_actions = self._target_greedy_policy.action(
      next_time_steps, dummy_state).action
  # Handle action_spec.shape=() and shape=(1,) by using the multi_dim_actions
  # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
  multi_dim_actions = tf.nest.flatten(
      self._action_spec.q_network)[0].shape.ndims > 0
  return common.index_with_actions(
      next_target_q_values,
      greedy_discrete_actions,
      multi_dim_actions=multi_dim_actions)
def _single_objective_loss(self,
                           objective_idx: int,
                           observations: tf.Tensor,
                           actions: tf.Tensor,
                           single_objective_values: tf.Tensor,
                           weights: types.Tensor = None,
                           training: bool = False) -> tf.Tensor:
  """Computes loss for a single objective.

  Args:
    objective_idx: The index into `self._objective_networks` for a specific
      objective network.
    observations: A batch of observations.
    actions: A batch of actions.
    single_objective_values: A batch of objective values shaped as
      [batch_size] for the objective that the network indexed by
      `objective_idx` is predicting.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output batch loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.
    training: Whether the loss is being used for training.

  Returns:
    loss: A `Tensor` containing the loss for the training step.

  Raises:
    ValueError: if `objective_idx` is not between 0 and
      `self._num_objectives`.
  """
  if objective_idx >= self._num_objectives or objective_idx < 0:
    raise ValueError(
        'objective_idx should be between 0 and {}, but is {}'.format(
            self._num_objectives, objective_idx))
  with tf.name_scope('loss_for_objective_{}'.format(objective_idx)):
    objective_network = self._objective_networks[objective_idx]
    sample_weights = weights if weights is not None else 1
    if self._heteroscedastic[objective_idx]:
      predictions, _ = objective_network(observations, training=training)
      predicted_values = predictions.q_value_logits
      predicted_log_variance = predictions.log_variance
      action_predicted_log_variance = common.index_with_actions(
          predicted_log_variance, tf.cast(actions, dtype=tf.int32))
      sample_weights = sample_weights * 0.5 * tf.exp(
          -action_predicted_log_variance)
      loss = 0.5 * tf.reduce_mean(action_predicted_log_variance)
      # loss = 1/(2 * var(x)) * (y - f(x))^2 + 1/2 * log var(x)
      # Kendall, Alex, and Yarin Gal. "What Uncertainties Do We Need in
      # Bayesian Deep Learning for Computer Vision?." Advances in Neural
      # Information Processing Systems. 2017. https://arxiv.org/abs/1703.04977
    else:
      predicted_values, _ = objective_network(observations, training=training)
      loss = tf.constant(0.0)

    action_predicted_values = common.index_with_actions(
        predicted_values, tf.cast(actions, dtype=tf.int32))
    loss += self._error_loss_fn(single_objective_values,
                                action_predicted_values, sample_weights)
    return loss
def loss(self, observations, actions, rewards, weights=None, training=False):
  """Computes loss for reward prediction training.

  Args:
    observations: A batch of observations.
    actions: A batch of actions.
    rewards: A batch of rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output batch loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.
    training: Whether the loss is being used for training.

  Returns:
    loss: A `LossInfo` containing the loss for the training step.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  with tf.name_scope('loss'):
    sample_weights = weights if weights else 1
    if self._heteroscedastic:
      predictions, _ = self._reward_network(observations, training=training)
      predicted_values = predictions.q_value_logits
      predicted_log_variance = predictions.log_variance
      action_predicted_log_variance = common.index_with_actions(
          predicted_log_variance, tf.cast(actions, dtype=tf.int32))
      sample_weights = sample_weights * 0.5 * tf.exp(
          -action_predicted_log_variance)
      loss = 0.5 * tf.reduce_mean(action_predicted_log_variance)
      # loss = 1/(2 * var(x)) * (y - f(x))^2 + 1/2 * log var(x)
      # Kendall, Alex, and Yarin Gal. "What Uncertainties Do We Need in
      # Bayesian Deep Learning for Computer Vision?." Advances in Neural
      # Information Processing Systems. 2017. https://arxiv.org/abs/1703.04977
    else:
      predicted_values, _ = self._reward_network(observations,
                                                 training=training)
      loss = tf.constant(0.0)

    action_predicted_values = common.index_with_actions(
        predicted_values, tf.cast(actions, dtype=tf.int32))

    # Apply Laplacian smoothing on the estimated rewards, if applicable.
    if self._laplacian_matrix is not None:
      smoothness_batched = tf.matmul(
          predicted_values,
          tf.matmul(self._laplacian_matrix, predicted_values,
                    transpose_b=True))
      loss += (self._laplacian_smoothing_weight * tf.reduce_mean(
          tf.linalg.tensor_diag_part(smoothness_batched) * sample_weights))

    loss += self._error_loss_fn(
        rewards,
        action_predicted_values,
        sample_weights,
        reduction=tf.compat.v1.losses.Reduction.MEAN)
    return tf_agent.LossInfo(loss, extra=())
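# Hedged standalone sketch of the heteroscedastic weighting used above: the
# squared error is scaled by 1/(2*var) and a 0.5*log var penalty is added,
# following Kendall & Gal (2017). All tensors below are illustrative values,
# not outputs of the reward network.
import tensorflow as tf

example_rewards = tf.constant([1.0, 0.0])
example_predicted = tf.constant([0.8, 0.3])
example_log_var = tf.constant([-1.0, 0.5])  # log variance of the chosen action

example_weights = 0.5 * tf.exp(-example_log_var)  # 1 / (2 * var)
example_loss = (
    tf.reduce_mean(example_weights *
                   tf.square(example_rewards - example_predicted)) +
    0.5 * tf.reduce_mean(example_log_var))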