def f(model_output, target_category):  # pylint: disable=invalid-name
  # Mean binary cross-entropy over the batch, assuming `model_output` holds
  # probabilities in (0, 1) and `target_category` holds 0/1 labels.
  shapes.assert_same_shape(model_output, target_category)
  batch_size = model_output.shape[0]
  j = jnp.dot(jnp.transpose(target_category), jnp.log(model_output))
  j += jnp.dot(jnp.transpose(1 - target_category), jnp.log(1 - model_output))
  j = -1.0 / batch_size * jnp.squeeze(j)
  return j

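# For intuition, a standalone sketch of the same cross-entropy computation,
# using plain NumPy in place of `jnp` and made-up toy values:
import numpy as np

model_output = np.array([[0.9], [0.2], [0.8], [0.4]])     # sigmoid outputs
target_category = np.array([[1.0], [0.0], [1.0], [0.0]])  # binary targets

batch_size = model_output.shape[0]
j = np.dot(np.transpose(target_category), np.log(model_output))
j += np.dot(np.transpose(1 - target_category), np.log(1 - model_output))
print(-1.0 / batch_size * np.squeeze(j))  # ~0.27, the mean cross-entropy
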
def f(model_output, targets, weights):  # pylint: disable=invalid-name
  # Weighted smooth L1 (Huber) loss: weighted mean of the elementwise
  # Huber distance between `model_output` and `targets`.
  shapes.assert_same_shape(model_output, targets)
  shapes.assert_same_shape(model_output, weights)
  l1_dist = jnp.abs(model_output - targets)
  smooth_dist = jnp.where(l1_dist < 1, 0.5 * l1_dist**2, l1_dist - 0.5)
  weighted_smooth_dist = weights * smooth_dist
  return jnp.sum(weighted_smooth_dist) / jnp.sum(weights)

def policy_batches_stream(self):
  """Use the RLTask self._task to create inputs to the policy model."""
  # For now TD-0 estimation of the value. TODO(pkozakowski): Support others?
  for np_trajectory in self._task.trajectory_batch_stream(
      self._policy_batch_size,
      epochs=self._replay_epochs,
      max_slice_length=self._max_slice_length,
      include_final_state=False,
  ):
    (q_values, actions, act_log_probs) = self._run_value_model(
        np_trajectory.observations, np_trajectory.dist_inputs)
    shapes.assert_same_shape(q_values, act_log_probs)

    # q_values shape: (batch_size, n_samples, length)
    if len(q_values.shape) != 3:
      raise ValueError('Q-values are expected to have shape [batch_size, '
                       'n_samples, length], got: %s' % str(q_values.shape))
    if q_values.shape[1] != self._q_value_n_samples:
      raise ValueError('Q-values dimension 1 should = n_samples, %d != %d' %
                       (q_values.shape[1], self._q_value_n_samples))
    if q_values.shape[0] != self._policy_batch_size:
      raise ValueError('Q-values dimension 0 should = policy batch size, '
                       '%d != %d' %
                       (q_values.shape[0], self._policy_batch_size))

    mask = np_trajectory.mask
    mask = np.reshape(mask, [mask.shape[0], 1] + list(mask.shape[1:]))
    mask = jnp.broadcast_to(mask, q_values.shape)
    shapes.assert_same_shape(mask, q_values)
    yield (np_trajectory.observations, actions, q_values, act_log_probs,
           mask)

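# The mask handling above turns a (batch_size, length) mask into the
# (batch_size, n_samples, length) shape of the Q-values. A minimal NumPy
# sketch with made-up shapes:
import numpy as np

batch_size, n_samples, length = 2, 3, 4
q_values = np.zeros((batch_size, n_samples, length))
mask = np.ones((batch_size, length))

mask = np.reshape(mask, [mask.shape[0], 1] + list(mask.shape[1:]))  # add axis
mask = np.broadcast_to(mask, q_values.shape)   # replicate over n_samples
print(mask.shape)  # (2, 3, 4)
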
def f(model_output, targets, weights):  # pylint: disable=invalid-name
  # Per-sequence accuracy: a sequence counts as accurate only if every
  # non-padding position (weight != 0) is predicted correctly.
  predictions = jnp.argmax(model_output, axis=-1)
  shapes.assert_same_shape(predictions, targets)
  position_is_padding = jnp.equal(weights, 0)
  position_is_accurate = jnp.logical_or(jnp.equal(predictions, targets),
                                        position_is_padding)
  sequence_is_accurate = jnp.all(position_is_accurate, axis=-1)
  return jnp.average(sequence_is_accurate)

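# A toy NumPy example (illustrative values only) showing how zero weights
# mark padding, so a mismatch at a padded position does not break a sequence:
import numpy as np

predictions = np.array([[3, 1, 0], [2, 2, 9]])
targets = np.array([[3, 1, 7], [2, 2, 2]])
weights = np.array([[1, 1, 0], [1, 1, 1]])  # row 0 ends in padding

position_is_padding = np.equal(weights, 0)
position_is_accurate = np.logical_or(np.equal(predictions, targets),
                                     position_is_padding)
print(np.all(position_is_accurate, axis=-1))  # [ True False] -> accuracy 0.5
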
def f(values, actions, returns, mask):
  ind_0, ind_1 = np.indices(actions.shape)
  # We calculate length using the shape of returns
  # and adequately remove a superfluous slice of values.
  # An analogous operation is done in value_batches_stream.
  length = returns.shape[1]
  values = values[:, :length, :]
  selected_values = values[ind_0, ind_1, actions]
  shapes.assert_same_shape(selected_values, returns)
  shapes.assert_same_shape(selected_values, mask)
  return jnp.sum(selected_values) / jnp.sum(mask)

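# A small NumPy sketch of the gather above, with invented shapes: `values`
# carries one extra time slice (the final state) that gets trimmed away.
import numpy as np

values = np.arange(24, dtype=np.float32).reshape(2, 4, 3)  # (batch, length+1, n_actions)
actions = np.array([[0, 2, 1], [1, 1, 0]])                 # (batch, length)
returns = np.ones((2, 3))

ind_0, ind_1 = np.indices(actions.shape)
values = values[:, :returns.shape[1], :]   # drop the superfluous slice
selected = values[ind_0, ind_1, actions]   # selected[b, t] = values[b, t, actions[b, t]]
print(selected.shape)  # (2, 3)
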
def f(model_output, targets, weights):  # pylint: disable=invalid-name
  """Returns elementwise-weighted L2 norm of `model_output - targets`.

  Args:
    model_output: Output from one batch, treated as an unanalyzed tensor.
    targets: Tensor of same shape as `model_output` containing element-wise
        target values.
    weights: Tensor of same shape as `model_output` and `targets`.
  """
  shapes.assert_same_shape(model_output, targets)
  shapes.assert_same_shape(targets, weights)
  l2 = weights * (model_output - targets)**2
  return jnp.sum(l2) / jnp.sum(weights)

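# A quick NumPy illustration (toy values) of how zero weights drop elements
# from both the numerator and the denominator:
import numpy as np

model_output = np.array([1.0, 2.0, 3.0])
targets = np.array([1.0, 0.0, 3.0])
weights = np.array([1.0, 1.0, 0.0])

l2 = weights * (model_output - targets)**2
print(np.sum(l2) / np.sum(weights))  # 2.0: only the middle error counts
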
def f(values, actions, returns, mask):
  ind_0, ind_1 = np.indices(actions.shape)
  # We calculate length using the shape of returns
  # and adequately remove a superfluous slice of values.
  # An analogous operation is done in value_batches_stream.
  length = returns.shape[1]
  values = values[:, :length, :]
  selected_values = values[ind_0, ind_1, actions]
  shapes.assert_same_shape(selected_values, returns)
  shapes.assert_same_shape(selected_values, mask)
  if self._smoothl1loss:
    return tl.SmoothL1Loss().forward((selected_values, returns, mask))
  else:
    return tl.L2Loss().forward((selected_values, returns, mask))

def f(model_output, targets, weights):  # pylint: disable=invalid-name
  """Returns weighted sum-of-squared-errors for `model_output` vs. `targets`.

  Args:
    model_output: Output from one batch, typically a 2- or 3-d array of
        float-valued elements.
    targets: Tensor of same shape as `model_output` containing element-wise
        target values.
    weights: Tensor of same shape as `model_output` and `targets`, containing
        element-wise weight values.
  """
  shapes.assert_same_shape(model_output, targets)
  shapes.assert_same_shape(targets, weights)
  weighted_sse = weights * (model_output - targets)**2
  return jnp.sum(weighted_sse) / jnp.sum(weights)

def smoothl1loss(model_output, targets, weights):  # pylint: disable=invalid-name
  r"""Returns weighted smooth L1 norm of `model_output - targets`.

  The smooth L1 loss, also known as the Huber loss, is defined as:

  .. math::
      z_i = \begin{cases}
            0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
            |x_i - y_i| - 0.5, & \text{otherwise}
            \end{cases}

  Args:
    model_output: Output from one batch, treated as an unanalyzed tensor.
    targets: Tensor of same shape as `model_output` containing element-wise
        target values.
    weights: Tensor of same shape as `model_output` and `targets`, containing
        element-wise weight values.
  """
  shapes.assert_same_shape(model_output, targets)
  shapes.assert_same_shape(targets, weights)
  l1_dist = jnp.abs(model_output - targets)
  smooth_dist = jnp.where(l1_dist < 1, 0.5 * l1_dist**2, l1_dist - 0.5)
  shapes.assert_same_shape(smooth_dist, weights)
  weighted_smooth_dist = weights * smooth_dist
  return jnp.sum(weighted_smooth_dist) / jnp.sum(weights)

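# The piecewise definition in one line of NumPy, with residuals straddling
# the |x - y| = 1 switch point (toy values):
import numpy as np

l1_dist = np.array([0.5, 1.0, 2.0])
smooth_dist = np.where(l1_dist < 1, 0.5 * l1_dist**2, l1_dist - 0.5)
print(smooth_dist)  # [0.125 0.5 1.5]: quadratic below 1, linear above,
                    # and the two branches agree at exactly 1.
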
def policy_batches_stream(self):
  """Use the RLTask self._task to create inputs to the policy model."""
  # For now TD-0 estimation of the value. TODO(pkozakowski): Support others?
  for np_trajectory in self._task.trajectory_batch_stream(
      self._policy_batch_size,
      epochs=self._replay_epochs,
      max_slice_length=self._max_slice_length,
      include_final_state=False,
  ):
    (q_values, actions) = self._run_value_model(
        np_trajectory.observations, np_trajectory.dist_inputs)
    # TODO(pkozakowski): Try max here.
    values = jnp.mean(q_values, axis=0)

    if len(values.shape) != 2:
      raise ValueError('Values are expected to have shape '
                       '[batch_size, length], got: %s' % str(values.shape))
    if values.shape[0] != self._policy_batch_size:
      raise ValueError('Values first dimension should = policy batch size, '
                       '%d != %d' %
                       (values.shape[0], self._policy_batch_size))

    # q_values shape: (n_samples, batch_size, length)
    # values shape: (batch_size, length)
    # Computing advantages by broadcasting over n_samples.
    advantages = q_values - values
    mask = jnp.broadcast_to(np_trajectory.mask, advantages.shape)
    shapes.assert_shape_equals(
        advantages, (self._q_value_n_samples,) + values.shape)
    shapes.assert_same_shape(mask, advantages)

    # Swapping the n_samples and batch_size axes, so the input is split
    # between accelerators along the batch_size axis.
    advantages = jnp.swapaxes(advantages, 0, 1)
    mask = jnp.swapaxes(mask, 0, 1)

    yield (np_trajectory.observations, actions, advantages, mask, mask)

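# The advantage computation relies on broadcasting plus one axis swap. A
# NumPy sketch with invented shapes (n_samples=3, batch_size=2, length=4):
import numpy as np

q_values = np.random.rand(3, 2, 4)          # (n_samples, batch_size, length)
values = np.mean(q_values, axis=0)          # (batch_size, length)
advantages = q_values - values              # broadcasts over n_samples
advantages = np.swapaxes(advantages, 0, 1)  # batch axis first for splitting
print(advantages.shape)                     # (2, 3, 4)
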
def L2Loss(inputs):
  y_hat, y, mask = inputs
  shapes.assert_same_shape(y_hat, y)
  shapes.assert_same_shape(y, mask)
  l2 = mask * (y_hat - y)**2
  return np.sum(l2) / np.sum(mask)

def f(model_output, targets):  # pylint: disable=invalid-name
  # Unweighted accuracy: the fraction of positions where the argmax
  # prediction matches the target.
  predictions = jnp.argmax(model_output, axis=-1)
  shapes.assert_same_shape(predictions, targets)
  n_total = predictions.size
  n_correct = jnp.sum(jnp.equal(predictions, targets))
  return n_correct / n_total

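# Toy NumPy example (made-up logits): two of three argmax predictions match.
import numpy as np

model_output = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
targets = np.array([1, 0, 0])

predictions = np.argmax(model_output, axis=-1)            # [1 0 1]
print(np.sum(predictions == targets) / predictions.size)  # 2/3
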
def RawL2Loss(inputs, **unused_kwargs):
  y_hat, y = inputs
  shapes.assert_same_shape(y_hat, y)
  return np.mean((y_hat - y)**2)

def f(model_output, targets, weights):  # pylint: disable=invalid-name
  shapes.assert_same_shape(model_output, targets)
  shapes.assert_same_shape(model_output, weights)
  weighted_sse = weights * (model_output - targets)**2
  return jnp.sum(weighted_sse) / jnp.sum(weights)

def MaskedL2Loss(inputs, **unused_kwargs):
  y_hat, y, mask = inputs
  shapes.assert_same_shape(y_hat, y)
  shapes.assert_same_shape(y, mask)
  l2 = mask * (y_hat - y)**2
  return np.sum(l2) / np.sum(mask)

def f(model_output, targets, weights):  # pylint: disable=invalid-name
  # Weighted accuracy: positions with weight 0 (e.g. padding) are excluded.
  predictions = jnp.argmax(model_output, axis=-1)
  shapes.assert_same_shape(predictions, targets)
  ones_and_zeros = jnp.equal(predictions, targets)
  return jnp.sum(ones_and_zeros * weights) / jnp.sum(weights)

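# Unlike the per-sequence variant earlier, this scores individual positions;
# a toy NumPy example where the only mismatch sits on a zero-weight pad:
import numpy as np

predictions = np.array([[1, 0, 2], [2, 2, 1]])
targets = np.array([[1, 0, 0], [2, 2, 1]])
weights = np.array([[1, 1, 0], [1, 1, 1]])

ones_and_zeros = np.equal(predictions, targets)
print(np.sum(ones_and_zeros * weights) / np.sum(weights))  # 1.0
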
def f(y_hat, y, mask):  # pylint: disable=invalid-name
  shapes.assert_same_shape(y_hat, y)
  shapes.assert_same_shape(y, mask)
  l2 = mask * (y_hat - y)**2
  return np.sum(l2) / np.sum(mask)