def ClippedObjectiveMean(
    dist_inputs, values, returns, actions, old_log_probs):
  """Clipped objective from the PPO algorithm."""
  advantages = returns - values
  probs_ratio = rl_layers.ProbsRatio(
      dist_inputs, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob)
  clipped_objective = rl_layers.ClippedObjective(
      probs_ratio, advantages, epsilon=self._epsilon)
  return jnp.mean(clipped_objective)
def f(preds, values, returns, actions, mask):
  advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
def f(values, weights):  # pylint: disable=invalid-name
  # This function assumes weights are 0 or 1.
  # Then compute 1: not-correct, 0: correct or masked
  not_correct = (1.0 - values) * weights
  axis_to_sum = list(range(1, len(not_correct.shape)))
  # Summing not-correct on all axes but batch. We're summing 0s and 1s,
  # so the sum is 0 if it's all 0 and >=1 in all other cases.
  not_correct_seq = np.sum(not_correct, axis=axis_to_sum)
  # Sequence is correct if not_correct_seq is 0, reverting here.
  correct_seq = 1.0 - np.minimum(1.0, not_correct_seq)
  return np.mean(correct_seq)  # Mean over batch.
def _WeightedSequenceMean(inputs, **unused_kwargs):
  """Returns a layer to compute weighted sequence accuracy mean."""
  values, weights = inputs
  # This function assumes weights are 0 or 1.
  not_correct = (1.0 - values) * weights  # 1: not-correct, 0: correct or masked
  axis_to_sum = list(range(1, len(not_correct.shape)))
  # Summing not-correct on all axes but batch. We're summing 0s and 1s,
  # so the sum is 0 if it's all 0 and >=1 in all other cases.
  not_correct_seq = np.sum(not_correct, axis=axis_to_sum)
  # Sequence is correct if not_correct_seq is 0, reverting here.
  correct_seq = 1.0 - np.minimum(1.0, not_correct_seq)
  return np.mean(correct_seq)  # Mean over batch.
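# Illustrative sketch (toy values, not part of the source above): how the
# masked per-sequence accuracy in _WeightedSequenceMean behaves on a small
# batch. `values` holds per-token correctness (1.0 = correct) and `weights`
# masks out padding positions.
import numpy as np

values = np.array([[1.0, 1.0, 1.0],    # all tokens correct
                   [1.0, 0.0, 1.0],    # one token wrong
                   [1.0, 0.0, 0.0]])   # wrong tokens are masked out below
weights = np.array([[1.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0],
                    [1.0, 0.0, 0.0]])  # last two positions are padding

not_correct = (1.0 - values) * weights
correct_seq = 1.0 - np.minimum(1.0, not_correct.sum(axis=1))
print(correct_seq)          # -> [1. 0. 1.]
print(correct_seq.mean())   # -> 0.666..., 2 of 3 sequences fully correct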
def AWRJointLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
  preds, values, returns, actions, mask = x
  advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
def A2CObjective(dist_inputs, values, returns, actions, mask,
                 log_prob_fun, normalize_advantages):
  """Definition of the Advantage Actor Critic (A2C) loss."""
  returns = returns.squeeze()
  values = values.squeeze()
  new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
  advantages = returns - values
  if normalize_advantages:
    advantages = advantages - jnp.mean(advantages)
    advantages /= jnp.std(advantages) + 1e-8
  return -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask)
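# Minimal numeric sketch (illustrative, not from the source): the masked A2C
# objective above, computed directly with jnp on made-up toy values.
import jax.numpy as jnp

new_log_probs = jnp.array([-0.5, -1.0, -2.0])
advantages = jnp.array([2.0, 1.0, 0.5])
mask = jnp.array([1.0, 1.0, 0.0])   # last timestep is padding

a2c = -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask)
print(a2c)  # -> -((-1.0) + (-1.0) + 0.0) / 2 = 1.0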
def predict(x, weights, state, rng):
  """Predict function JIT-compiled and parallelized as requested."""
  res, state = _combine_devices(model_predict(
      reshape_by_device(x, n_devices),
      weights,
      state,
      jnp.stack(math.random.split(rng, n_devices))))
  if do_mean:
    return math.nested_map(lambda y: jnp.mean(y, axis=0), res), state
  else:
    return res, state
def actor_loss(actions, advantage_weights, log_probab_actions_new,
               state=None):
  """Actor loss."""
  # log_probab_actions_new's shape is (AB, 1, #C, #A), AB is actor batch.
  lp = jnp.squeeze(log_probab_actions_new, axis=1)
  AB, NC = actions.shape  # pylint: disable=invalid-name
  log_probs = lp[jnp.arange(AB)[:, None], jnp.arange(NC)[None, :], actions]
  # TODO(afrozm): Clarify this.
  # log_probs are shaped (AB, #C), however advantage_weights are (AB,)
  return -1.0 * jnp.mean(log_probs * advantage_weights[:, None]), state
def f(dist_inputs, values, returns, actions, old_log_probs):
  """Clipped objective from the PPO algorithm."""
  advantages = returns - values
  probs_ratio = rl_layers.ProbsRatio(
      dist_inputs, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob)
  # advantages are of the shape [128,1,1]
  # and probs_ratio are of the shape [128,1]
  advantages = advantages.squeeze(axis=2)
  clipped_objective = rl_layers.ClippedObjective(
      probs_ratio, advantages, epsilon=self._epsilon)
  return jnp.mean(clipped_objective)
def f(dist_inputs, values, returns, actions, old_log_probs, mask):
  """A2C objective mean."""
  del old_log_probs
  a2c_objective = rl_layers.A2CObjective(
      dist_inputs, values, returns, actions, mask,
      log_prob_fun=self._policy_dist.log_prob,
      normalize_advantages=self._normalize_advantages)
  return jnp.mean(a2c_objective)
def f(dist_inputs, values, returns, actions, old_log_probs):
  """PPO objective mean."""
  ppo_objective = rl_layers.PPOObjective(
      dist_inputs, values, returns, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob,
      epsilon=self._epsilon,
      normalize_advantages=self._normalize_advantages)
  return jnp.mean(ppo_objective)
def ApproximateKLDivergence(dist_inputs, actions, old_log_probs, log_prob_fun):
  """Approximate KL divergence between the new and the old policy."""
  # TODO(henrykm): Clarify the old_log_probs and squeezing
  # Old log probs have an undesirable extra dimension which we remove here
  old_log_probs = jnp.array(old_log_probs.squeeze(axis=-1),
                            dtype=jnp.float32)
  new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
  # Approximate the KL divergence from the difference between
  # the new and the old log-probs.
  approximate_kl_divergence = 0.5 * \
      jnp.mean(new_log_probs - old_log_probs) ** 2
  return approximate_kl_divergence
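# Illustrative sketch (toy values, not from the source): the quantity above
# squares the *mean* of the log-prob differences, so offsetting changes in
# opposite directions cancel out before squaring.
import jax.numpy as jnp

new_log_probs = jnp.array([-0.8, -1.2, -1.0])
old_log_probs = jnp.array([-1.0, -1.0, -1.0])
approx_kl = 0.5 * jnp.mean(new_log_probs - old_log_probs) ** 2
print(approx_kl)  # -> 0.0, since the per-action differences average to zero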
def fn(dist_inputs, actions, q_values, act_log_probs, mask):
  del dist_inputs, actions, mask
  q_values = jnp.swapaxes(q_values, 0, 1)
  act_log_probs = jnp.swapaxes(act_log_probs, 0, 1)
  if self._sample_all_discrete_actions:
    values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
  else:
    values = jnp.mean(q_values, axis=0)
  advantages = q_values - values  # Broadcasting values over n_samples
  if preprocess:
    advantages = self._preprocess_advantages(advantages)
  return advantages
def A2CObjective(dist_inputs, values, returns, dones, rewards,
                 actions, mask, log_prob_fun, normalize_advantages):
  """Definition of the Advantage Actor Critic (A2C) loss."""
  # dist_inputs of the shape float32[128,1,18]
  # values of the shape float32[128,1,1]
  # returns of the shape float32[128,1,1]
  # dones of the shape int32[128,1,1]
  # actions of the shape int32[128,1]
  # and mask of the shape float32[128,1]
  # We have to squeeze values and returns, because we
  # are planning to compute (return - values) * new_log_probs * mask
  # and all of them should be of the same dimension
  values = values.squeeze(axis=2)
  returns = returns.squeeze(axis=2)
  dones = dones.squeeze(axis=2)
  rewards = rewards.squeeze(axis=2)
  assert rewards.shape == dones.shape, (
      f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
  assert dones.shape == values.shape, (
      f'dones.shape was {dones.shape} and values.shape was {values.shape}')
  assert returns.shape == values.shape, (
      f'returns.shape was {returns.shape} and values.shape was {values.shape}')
  assert values.shape == mask.shape, (
      f'values.shape was {values.shape} and mask.shape was {mask.shape}')
  assert returns.shape[0] == dist_inputs.shape[0], (
      f'returns.shape[0] was {returns.shape[0]} and dist_inputs.shape[0] was '
      f'{dist_inputs.shape[0]}')

  new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
  assert new_log_probs.shape == mask.shape, (
      f'new_log_probs.shape was {new_log_probs.shape} and mask.shape was '
      f'{mask.shape}')

  # jaxified versions of
  # returns[dones] = rewards[dones]
  # values[dones] = 0
  returns = jnp.where(dones, rewards, returns)
  values = jnp.where(dones, jnp.zeros_like(values), values)

  advantages = returns - values
  if normalize_advantages:
    advantages = advantages - jnp.mean(advantages)
    advantages /= jnp.std(advantages) + 1e-8
  assert new_log_probs.shape == advantages.shape, (
      f'new_log_probs.shape was {new_log_probs.shape} and advantages.shape '
      f'was {advantages.shape}')

  # One of the motivations for the squeezes and assertions is to
  # avoid [128,1] * [128,1,1] * [128] multiplications in the definition
  # of the a2c objective - we insist on the same shapes
  a2c_objective = -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask)
  return a2c_objective
def f(new_log_probs, advantages, old_log_probs, mask):
  # new_log_probs of the shape float32[128,1]
  # advantages of the shape int32[128,1]
  # old_log_probs of the shape int32[128,1]
  # mask of the shape int32[128,1]
  if new_log_probs.shape != advantages.shape:
    raise ValueError('New log-probs and advantages shapes '
                     'should be the same, %s != %s' % (new_log_probs.shape,
                                                       advantages.shape))
  if new_log_probs.shape != old_log_probs.shape:
    raise ValueError('New log-probs and old log-probs shapes '
                     'should be the same, %s != %s' % (new_log_probs.shape,
                                                       old_log_probs.shape))
  if new_log_probs.shape != mask.shape:
    raise ValueError('New log-probs and mask shapes should be the same'
                     ', %s != %s' % (new_log_probs.shape, mask.shape))

  # The ratio between new_probs and old_probs expressed
  # using log_probs and exponentiation
  probs_ratio = jnp.exp(new_log_probs - old_log_probs)
  if advantages.shape != probs_ratio.shape:
    raise ValueError('Advantages and probs_ratio shapes '
                     'should be the same, %s != %s' % (advantages.shape,
                                                       probs_ratio.shape))

  unclipped_objective = probs_ratio * advantages
  clipped_objective = jnp.clip(probs_ratio,
                               1 - self._epsilon,
                               1 + self._epsilon) * advantages
  if unclipped_objective.shape != probs_ratio.shape:
    raise ValueError('unclipped_objective and probs_ratio shapes '
                     'should be the same, %s != %s' % (
                         unclipped_objective.shape, probs_ratio.shape))

  ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
  if ppo_objective.shape != mask.shape:
    raise ValueError('ppo_objective and mask shapes '
                     'should be the same, %s != %s' % (
                         ppo_objective.shape, mask.shape))

  ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask)
  entropy_vec = self._policy_dist.entropy(
      new_log_probs) * self._entropy_coeff
  entropy_loss = jnp.mean(entropy_vec)
  combined_loss = ppo_loss - entropy_loss
  return combined_loss
def ProbsRatioMean(x, **unused_kwargs):
  """Probability Ratio Mean from the PPO algorithm."""
  dist_inputs, _, _, actions, old_log_probs = x
  new_log_probs = self._policy_dist.log_prob(dist_inputs, actions)
  # Old log probs have an undesirable extra dimension which we remove here
  old_log_probs = jnp.array(old_log_probs.squeeze(axis=-1),
                            dtype=jnp.float32)
  new_log_probs = jnp.array(new_log_probs.squeeze(axis=-1))
  # The ratio between new_probs and old_probs expressed
  # using log_probs and exponentiation
  probs_ratio = jnp.exp(new_log_probs - old_log_probs)
  return jnp.mean(probs_ratio)
def Mean(axis=-1, keepdims=False):
  """Returns a layer that computes mean values using one tensor axis.

  `Mean` uses one tensor axis to form groups of values and replaces each group
  with the mean value of that group. The resulting values can either remain
  in their own size 1 axis (`keepdims=True`), or that axis can be removed from
  the overall tensor (default `keepdims=False`), lowering the rank of the
  tensor by one.

  Args:
    axis: Axis along which values are grouped for computing a mean.
    keepdims: If `True`, keep the resulting size 1 axis as a separate tensor
        axis; else, remove that axis.
  """
  return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))
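# Usage sketch (illustrative; assumes this function is exposed as
# trax.layers.Mean, which has no weights and so can be applied directly).
import numpy as np
from trax import layers as tl

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
print(tl.Mean(axis=-1)(x))                  # -> [2. 5.]
print(tl.Mean(axis=-1, keepdims=True)(x))   # -> [[2.] [5.]]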
def forward(self, inputs, weights):
  gamma, beta, epsilon_l = weights

  epsilon = self._init_epsilon
  if epsilon_l is not base.EMPTY_WEIGHTS:
    epsilon += np.abs(epsilon_l[0])

  # Omit B and C
  axis = tuple(range(1, len(np.shape(inputs)) - 1))
  # (B, 1, 1, C)
  nu2 = np.mean(inputs**2, axis=axis, keepdims=True)
  # (B, W, H, C)
  xhat = inputs / np.sqrt(nu2 + epsilon)

  return gamma * xhat + beta
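# Toy sketch (illustrative only, plain NumPy) of the normalization used in
# forward(): divide by the root mean square taken over all axes except batch
# and channels, here for a (B, W, H, C) input with a fixed epsilon.
import numpy as np

inputs = np.random.rand(2, 4, 4, 3)                    # (B, W, H, C)
nu2 = np.mean(inputs**2, axis=(1, 2), keepdims=True)   # (B, 1, 1, C)
xhat = inputs / np.sqrt(nu2 + 1e-6)
print(xhat.shape)  # -> (2, 4, 4, 3)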
def PPOObjective(dist_inputs, values, returns, actions, old_log_probs,
                 log_prob_fun, epsilon, normalize_advantages):
  """PPO Objective."""
  # Returns and values are arriving with two extra dimensions
  # TODO(henrykm): remove these dimensions at an earlier stage?
  returns = returns.squeeze()
  values = values.squeeze()
  probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
  advantages = returns - values
  if normalize_advantages:
    advantages = advantages - jnp.mean(advantages)
    advantages /= jnp.std(advantages) + 1e-8
  unclipped_objective = UnclippedObjective(probs_ratio, advantages)
  clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
  ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
  return ppo_objective
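# Numeric sketch (illustrative only, toy values) of the clipped-surrogate
# computation performed by PPOObjective, written directly in jnp.
import jax.numpy as jnp

epsilon = 0.2
probs_ratio = jnp.array([0.5, 1.0, 1.5])
advantages = jnp.array([1.0, 1.0, 1.0])

unclipped = probs_ratio * advantages
clipped = jnp.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages
ppo_objective = jnp.minimum(unclipped, clipped)
print(ppo_objective)  # -> [0.5 1.  1.2]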
def f(dist_inputs, values, returns, dones, rewards,
      actions, old_log_probs):
  """Unclipped objective Mean from the PPO algorithm."""
  del dones, rewards
  advantages = returns - values
  probs_ratio = rl_layers.ProbsRatio(
      dist_inputs, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob)
  # advantages are of the shape [128,1,1]
  # and probs_ratio are of the shape [128,1]
  advantages = advantages.squeeze(axis=2)
  unclipped_objective = rl_layers.UnclippedObjective(
      probs_ratio, advantages)
  return jnp.mean(unclipped_objective)
def PPOObjective(dist_inputs, values, returns, actions, old_log_probs,
                 log_prob_fun, epsilon, normalize_advantages):
  """PPO Objective."""
  # dist_inputs of the shape float32[128,1,18]
  # values of the shape float32[128,1,1]
  # returns of the shape float32[128,1,1]
  # actions of the shape int32[128,1]
  # and old_log_probs of the shape float32[128,1]
  returns = returns.squeeze(axis=2)
  values = values.squeeze(axis=2)
  assert returns.shape == values.shape, (
      f'returns.shape was {returns.shape} and values.shape was {values.shape}')
  assert returns.shape == old_log_probs.shape, (
      f'returns.shape was {returns.shape} and '
      f'old_log_probs.shape was {old_log_probs.shape}')

  probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
  assert probs_ratio.shape == old_log_probs.shape, (
      f'probs_ratio.shape was {probs_ratio.shape} and '
      f'old_log_probs.shape was {old_log_probs.shape}')

  advantages = returns - values
  if normalize_advantages:
    advantages = advantages - jnp.mean(advantages)
    advantages /= jnp.std(advantages) + 1e-8
  assert old_log_probs.shape == advantages.shape, (
      f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape '
      f'was {advantages.shape}')

  unclipped_objective = UnclippedObjective(probs_ratio, advantages)
  assert unclipped_objective.shape == advantages.shape, (
      f'advantages.shape was {advantages.shape} and '
      f'unclipped_objective.shape was {unclipped_objective.shape}')

  clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
  assert clipped_objective.shape == advantages.shape, (
      f'advantages.shape was {advantages.shape} and '
      f'clipped_objective.shape was {clipped_objective.shape}')

  ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
  assert ppo_objective.shape == advantages.shape, (
      f'advantages.shape was {advantages.shape} and '
      f'ppo_objective.shape was {ppo_objective.shape}')

  return ppo_objective
def critic_loss(observations, target_values, value_predictions_new,
                state=None):
  """Critic loss."""
  # There is no padding involved here, these are all observations.
  (batch, *obs_shape) = observations.shape
  del obs_shape
  if (batch,) != target_values.shape:
    raise ValueError(f'batch dimension is not the same: obs batch {batch}'
                     f' vs target values batch {target_values.shape[0]}')
  # TODO(afrozm): In the reference implementation, they pass the target
  # through a trained normalizer before subtracting.
  loss = 0.5 * jnp.mean(jnp.square(target_values - value_predictions_new))
  return loss, state
def policy_inputs(self, trajectory, values):
  """Create inputs to policy model from a TrajectoryNp and values."""
  # How much TD to use is determined by the added policy slice length,
  # as the policy batches need to be this much longer to calculate TD.
  advantages = self._advantage_estimator(
      trajectory.rewards,
      trajectory.returns,
      values,
      gamma=self._task.gamma,
      n_extra_steps=self._added_policy_slice_length,
  )
  if self._advantage_normalization:
    advantages = (
        (advantages - jnp.mean(advantages)) /
        (jnp.std(advantages) + self._advantage_normalization_epsilon))
  # Observations should be the same length as advantages - so if we are
  # using n_extra_steps, we need to trim the length to match.
  obs = trajectory.observations[:, :advantages.shape[1]]
  act = trajectory.actions[:, :advantages.shape[1]]
  old_logps = trajectory.log_probs[:, :advantages.shape[1]]
  mask = trajectory.mask[:, :advantages.shape[1]]  # Mask to zero-out padding.
  # Shape checks to help debugging.
  if len(advantages.shape) != 2:
    raise ValueError('Advantages are expected to have shape '
                     '[batch_size, length], got: %s' % str(advantages.shape))
  if act.shape[0:2] != advantages.shape:
    raise ValueError('First 2 dimensions of actions should be the same as in '
                     'advantages, %s != %s' % (act.shape[0:2],
                                               advantages.shape))
  if obs.shape[0:2] != advantages.shape:
    raise ValueError('First 2 dimensions of observations should be the same '
                     'as in advantages, %s != %s' % (obs.shape[0:2],
                                                     advantages.shape))
  if old_logps.shape != advantages.shape:
    raise ValueError('Old log-probs and advantages shapes should be the same'
                     ', %s != %s' % (old_logps.shape, advantages.shape))
  if mask.shape != advantages.shape:
    raise ValueError('Mask and advantages shapes should be the same'
                     ', %s != %s' % (mask.shape, advantages.shape))
  return (obs, act, advantages, old_logps, mask)
def LossInput(dist_inputs, actions, q_values, act_log_probs, mask):  # pylint: disable=invalid-name
  """Calculates action log probabilities and normalizes advantages."""
  # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
  q_values = jnp.swapaxes(q_values, 0, 1)
  mask = jnp.swapaxes(mask, 0, 1)
  actions = jnp.swapaxes(actions, 0, 1)
  act_log_probs = jnp.swapaxes(act_log_probs, 0, 1)

  # TODO(pkozakowski,lukaszkaiser): Try max here, or reweighting?
  # Reweight: values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
  values = jnp.mean(q_values, axis=0)
  advantages = q_values - values  # Broadcasting values over n_samples
  advantages = self._preprocess_advantages(advantages)

  # Broadcast inputs and calculate log-probs
  dist_inputs = jnp.broadcast_to(
      dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape)
  log_probs = self._policy_dist.log_prob(dist_inputs, actions)
  return (log_probs, advantages, act_log_probs, mask)
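# Shape sketch (illustrative only, toy shapes) for the advantage computation
# above: average the sampled Q-values over the n_samples axis and let
# broadcasting subtract that baseline from every sample.
import jax.numpy as jnp

n_samples, batch_size, length = 4, 2, 3
q_values = jnp.ones((batch_size, n_samples, length))
q_values = jnp.swapaxes(q_values, 0, 1)     # (n_samples, batch, length)
values = jnp.mean(q_values, axis=0)         # (batch, length)
advantages = q_values - values              # broadcast over n_samples
print(advantages.shape)  # -> (4, 2, 3)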
def policy_metrics(self):
  metrics = {
      'policy_loss': self.policy_loss,
      'advantage_mean': tl.Serial(
          self._policy_inputs_to_advantages(False),
          tl.Fn('Mean', lambda x: jnp.mean(x))  # pylint: disable=unnecessary-lambda
      ),
      'advantage_std': tl.Serial(
          self._policy_inputs_to_advantages(False),
          tl.Fn('Std', lambda x: jnp.std(x))  # pylint: disable=unnecessary-lambda
      )
  }
  metrics.update(awr_metrics(
      self._beta,
      preprocess_layer=self._policy_inputs_to_advantages(True)))
  return metrics
def train_step(self, batch):
  """Run one training step and update self._opt_state."""
  # Calculate the current optimizer parameters.
  # TODO(pkozakowski): Optimizer parameters get polluted with model state,
  # which doesn't break anything but is weird. Filter it out.
  opt_param_updates = self._for_n_devices(
      math.nested_map(np.array, self.nontrainable_params))
  opt_state = self._opt_state
  opt_state.opt_params.update(opt_param_updates)

  # Run the update.
  (weights, slots, stat), self._model_state, self._rngs = self._jit_update_fn(
      self._step, opt_state, batch, self._model_state, self._rngs)
  self._model_state = self._map_to_state_dicts(self._state_dicts_update)
  self._opt_state = opt_state._replace(weights=weights, slots=slots)
  if self._should_log_now():
    for name, value in stat.items():
      scalar_value = np.mean(value)  # On multiple devices, take the mean.
      self._train_sw.scalar('training/' + name, scalar_value,
                            step=self._step)
  self._step += 1
def f(log_probs, advantages, old_log_probs, mask):
  del old_log_probs  # Not used in A2C.
  # log_probs of the shape float32[128,1]
  # advantages of the shape int32[128,1]
  # mask of the shape int32[128,1]
  if log_probs.shape != advantages.shape:
    raise ValueError('New log-probs and advantages shapes '
                     'should be the same, %s != %s' % (log_probs.shape,
                                                       advantages.shape))
  if log_probs.shape != mask.shape:
    raise ValueError('New log-probs and mask shapes should be the same'
                     ', %s != %s' % (log_probs.shape, mask.shape))

  a2c_objective = -jnp.sum(log_probs * advantages * mask) / jnp.sum(mask)

  entropy_vec = self._policy_dist.entropy(log_probs) * self._entropy_coeff
  entropy_loss = jnp.mean(entropy_vec)

  combined_loss = a2c_objective - entropy_loss
  return combined_loss
def policy_batches_stream(self):
  """Use the RLTask self._task to create inputs to the policy model."""
  # For now TD-0 estimation of the value. TODO(pkozakowski): Support others?
  for np_trajectory in self._task.trajectory_batch_stream(
      self._policy_batch_size,
      epochs=self._replay_epochs,
      max_slice_length=self._max_slice_length,
      include_final_state=False,
  ):
    (q_values, actions) = self._run_value_model(
        np_trajectory.observations, np_trajectory.dist_inputs)
    # TODO(pkozakowski): Try max here.
    values = jnp.mean(q_values, axis=0)

    if len(values.shape) != 2:
      raise ValueError('Values are expected to have shape '
                       '[batch_size, length], got: %s' % str(values.shape))
    if values.shape[0] != self._policy_batch_size:
      raise ValueError('Values first dimension should = policy batch size, '
                       '%d != %d' % (values.shape[0],
                                     self._policy_batch_size))

    # q_values shape: (n_samples, batch_size, length)
    # values shape: (batch_size, length)
    # Computing advantages by broadcasting over n_samples.
    advantages = q_values - values
    mask = jnp.broadcast_to(np_trajectory.mask, advantages.shape)
    shapes.assert_shape_equals(
        advantages, (self._q_value_n_samples,) + values.shape)
    shapes.assert_same_shape(mask, advantages)

    # Swapping the n_samples and batch_size axes, so the input is split
    # between accelerators along the batch_size axis.
    advantages = jnp.swapaxes(advantages, 0, 1)
    mask = jnp.swapaxes(mask, 0, 1)

    yield (np_trajectory.observations, actions, advantages, mask, mask)
def test_custom_id_grad(self):
  class IdWithIdGrad(base.Layer):

    def forward(self, x, weights):
      return x

    @property
    def has_backward(self):
      return True

    def backward(self, inputs, output, grad, weights, state, new_state, rng):
      return (inputs, ())

  layer = IdWithIdGrad()
  rng = math.random.get_prng(0)
  input_signature = shapes.ShapeDtype((9, 17))
  random_input = math.random.uniform(rng, input_signature.shape,
                                     minval=-1.0, maxval=1.0)
  layer.init(input_signature)
  f = lambda x: jnp.mean(layer(x))
  grad = math.grad(f)(random_input)
  self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
  self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
def predict(x, weights, state, rng):
  """Predict function jitted and parallelized as requested."""
  res, state = _combine_devices(model_predict(
      reshape_by_device(x, n_devices),
      weights,
      state,
      np.stack(math.random.split(rng, n_devices))))
  return math.nested_map(lambda y: np.mean(y, axis=0), res), state