def update(self, step, grads, weights, slots, opt_params):
  """Computes one optimizer step (Adafactor-style factored second moments)."""
  updates = []
  learning_rate = opt_params['learning_rate']
  beta1 = opt_params['beta1']
  decay_rate = opt_params['decay_rate']
  clipping_threshold = opt_params['clipping_threshold']
  weight_decay_rate = opt_params['weight_decay_rate']
  weight_decay_n_steps = opt_params['weight_decay_n_steps']
  # Linearly anneal the weight decay rate to 0 over weight_decay_n_steps.
  weight_decay_rate = jnp.where(
      weight_decay_n_steps < 1,  # if weight_decay_n_steps == 0, ignore it
      weight_decay_rate,
      (weight_decay_rate * jnp.maximum(weight_decay_n_steps - step, 0.0) /
       jnp.maximum(weight_decay_n_steps, 0.0)))
  epsilon1 = opt_params['epsilon1']
  epsilon2 = opt_params['epsilon2']
  decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
  update_scale = learning_rate
  if self._multiply_by_parameter_scale:
    update_scale *= jnp.maximum(
        jnp.sqrt(jnp.mean(weights * weights)), epsilon2)
  mixing_rate = 1.0 - decay_rate

  grads_sqr = grads * grads
  if self._factored and len(weights.shape) >= 2:
    # Keep per-row and per-column running means of the squared gradients
    # instead of the full second-moment matrix.
    v_row = slots.pop(0)
    v_col = slots.pop(0)
    new_v_row = (
        decay_rate * v_row + mixing_rate * jnp.mean(grads_sqr, axis=-1))
    new_v_col = (
        decay_rate * v_col + mixing_rate * jnp.mean(grads_sqr, axis=-2))
    updates.extend([new_v_row, new_v_col])
    row_mean = jnp.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (row_mean / (new_v_row + epsilon1))**0.5
    col_factor = (new_v_col + epsilon1)**-0.5
    y = (grads * jnp.expand_dims(row_factor, axis=-1) *
         jnp.expand_dims(col_factor, axis=-2))
  else:
    v = slots.pop(0)
    new_v = decay_rate * v + mixing_rate * grads_sqr
    updates.append(new_v)
    y = grads * (new_v + epsilon1)**-0.5

  if self._do_clipping:
    clipping_denom = (
        jnp.maximum(1.0, jnp.sqrt(jnp.mean(y * y)) / clipping_threshold))
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._do_momentum:
    m = slots.pop(0)
    new_m = beta1 * m + (1.0 - beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_weights = (1 - weight_decay_rate) * weights - subtrahend
  # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
  return new_weights.astype(weights.dtype), updates
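# A minimal standalone sketch (not the optimizer's code) of the factored
# second-moment idea behind the jnp.mean calls above: for a 2-D weight
# matrix, per-row and per-column means of the squared gradients are kept,
# and a rank-1 outer product reconstructs an approximation of the full
# elementwise second moment. Names and numbers here are illustrative.
import jax.numpy as jnp

def factored_second_moment_approx(grads_sqr, epsilon1=1e-30):
  v_row = jnp.mean(grads_sqr, axis=-1)  # per-row means, shape [rows]
  v_col = jnp.mean(grads_sqr, axis=-2)  # per-column means, shape [cols]
  row_mean = jnp.mean(v_row)
  # Rank-1 reconstruction: approx[i, j] = v_row[i] * v_col[j] / mean(v_row).
  return v_row[:, None] * v_col[None, :] / (row_mean + epsilon1)

g = jnp.array([[1., 2.], [3., 4.]])
print(factored_second_moment_approx(g * g))  # approximates g * g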
def mean_or_pmean(n_devices, x, axis=None):
  """jnp.mean or pmean.

  `x` is a distributed value. Directly calling jnp.mean on `x` means
  stacking x's components together to form a large array and then doing
  jnp.mean on it. In TF, stacking `x` will introduce a D2H copy, so we use
  a collective (pmean) here instead of directly calling jnp.mean for TF.

  Args:
    n_devices: number of devices.
    x: a distributed array.
    axis: the axis to reduce. Can only be 0 or None.

  Returns:
    A local array.
  """
  if fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1:
    if axis not in (None, 0):
      raise ValueError('axis can only be None or 0')
    x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices
    if axis is None:
      x = jnp.mean(x)
    return x
  else:
    return jnp.mean(x, axis=axis)
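# A hedged usage note: on one device (or a non-TF backend) the helper is
# equivalent to a plain jnp.mean, so the example below uses made-up
# per-device values and exercises only the jnp.mean fallback path.
import jax.numpy as jnp

per_device_losses = jnp.array([0.5, 0.7, 0.3, 0.9])  # one entry per device
# mean_or_pmean(1, per_device_losses) would take this branch:
print(jnp.mean(per_device_losses))  # 0.6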
def forward(self, x):
  scale, bias = self.weights
  mean = jnp.mean(x, axis=-1, keepdims=True)
  centered = x - mean
  variance = jnp.mean(centered * centered, axis=-1, keepdims=True)
  norm_inputs = centered / jnp.sqrt(variance + self._epsilon)
  return norm_inputs * scale + bias
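# A quick numeric sanity check (a sketch, not library code): before `scale`
# and `bias` are applied, each row of the normalized output should have
# approximately zero mean and unit variance.
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
mean = jnp.mean(x, axis=-1, keepdims=True)
centered = x - mean
variance = jnp.mean(centered * centered, axis=-1, keepdims=True)
normed = centered / jnp.sqrt(variance + 1e-6)
print(jnp.mean(normed, axis=-1))           # ~0.0
print(jnp.mean(normed * normed, axis=-1))  # ~1.0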
def forward(self, x):
  scale, bias = self.weights
  mean = jnp.mean(x, axis=-1, keepdims=True)
  sub = x - mean
  variance = jnp.mean(sub * sub, axis=-1, keepdims=True)
  norm_inputs = sub / jnp.sqrt(variance + self._epsilon)
  return norm_inputs * scale + bias
def f(preds, values, returns, actions, mask):
  advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
def train_step(self, batch):
  """Run one training step and update self._opt_state."""
  # Calculate the current optimizer parameters.
  opt_param_updates = self._for_n_devices(
      {'learning_rate': np.array(self.learning_rate)})
  opt_state = self._opt_state
  opt_state.opt_params.update(opt_param_updates)

  # Run the update.
  weights, slots, opt_params = opt_state
  (weights, slots), stat, self._model_state, self._rngs = self._jit_update_fn(
      (weights, slots), self._step, opt_params, batch,
      self._model_state, self._rngs)
  self._opt_state = opt_state._replace(weights=weights, slots=slots)

  if self._should_log_now():
    for name, value in stat.items():
      # TODO(afrozm): value is a scalar, but sometimes JAX is crashing here
      # with a device put array error complaining that it should be an array.
      # On multiple devices, take the mean.
      scalar_value = np.mean(np.array(value))
      self._train_sw.scalar('training/' + name, scalar_value, step=self._step)
  self._step += 1
def train_step(self, batch):
  """Run one training step and update self._opt_state."""
  # Calculate the current optimizer parameters.
  # TODO(pkozakowski): Optimizer parameters get polluted with model state,
  # which doesn't break anything but is weird. Filter it out.
  opt_param_updates = self._for_n_devices(
      fastmath.nested_map(np.array, self.nontrainable_params))
  opt_state = self._opt_state
  opt_state.opt_params.update(opt_param_updates)

  # Run the update.
  weights, slots, opt_params = opt_state
  (weights, slots), stat, self._model_state, self._rngs = self._jit_update_fn(
      (weights, slots), self._step, opt_params, batch,
      self._model_state, self._rngs)
  self._opt_state = opt_state._replace(weights=weights, slots=slots)

  if self._should_log_now():
    for name, value in stat.items():
      scalar_value = np.mean(value)  # On multiple devices, take the mean.
      self._train_sw.scalar('training/' + name, scalar_value, step=self._step)
  self._step += 1
def test_custom_zero_grad(self, backend):

  class IdWithZeroGrad(tl.Layer):

    def forward(self, x):
      return x

    @property
    def has_backward(self):
      return True

    def backward(self, inputs, output, grad, weights, state, new_state, rng):
      return (jnp.zeros_like(grad), ())

  with fastmath.use_backend(backend):
    layer = IdWithZeroGrad()
    rng = fastmath.random.get_prng(0)
    input_signature = shapes.ShapeDtype((9, 17))
    random_input = fastmath.random.uniform(
        rng, input_signature.shape, minval=-1.0, maxval=1.0)
    layer.init(input_signature)
    f = lambda x: jnp.mean(layer(x))
    grad = fastmath.grad(f)(random_input)
    self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
    self.assertEqual(sum(sum(grad * grad)), 0.0)  # Each one is 0.
def _preprocess_advantages(self, advantages):
  if self._advantage_normalization:
    advantages = (
        (advantages - jnp.mean(advantages)) /
        (jnp.std(advantages) + self._advantage_normalization_epsilon))
  return advantages
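# A minimal sketch of the normalization above with made-up numbers:
# subtracting the batch mean and dividing by the batch std (plus epsilon
# for stability) yields advantages with ~zero mean and ~unit scale.
import jax.numpy as jnp

advantages = jnp.array([1.0, 3.0, -2.0, 6.0])
eps = 1e-8
normalized = (advantages - jnp.mean(advantages)) / (jnp.std(advantages) + eps)
print(jnp.mean(normalized), jnp.std(normalized))  # ~0.0, ~1.0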
def PPOObjective(dist_inputs, values, returns, dones, rewards,
                 actions, old_log_probs, log_prob_fun, epsilon,
                 normalize_advantages):
  """PPO Objective."""
  # dist_inputs of the shape float32[128,1,18]
  # values of the shape float32[128,1,1]
  # returns of the shape float32[128,1,1]
  # dones of the shape float32[128,1,1]
  # rewards of the shape int32[128,1,1]
  # actions of the shape int32[128,1]
  # and old_log_probs of the shape float32[128,1]
  returns = returns.squeeze(axis=2)
  values = values.squeeze(axis=2)
  dones = dones.squeeze(axis=2)
  rewards = rewards.squeeze(axis=2)
  assert rewards.shape == dones.shape, (
      f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
  assert dones.shape == values.shape, (
      f'dones.shape was {dones.shape} and values.shape was {values.shape}')
  assert returns.shape == values.shape, (
      f'returns.shape was {returns.shape} and values.shape was {values.shape}')
  assert returns.shape == old_log_probs.shape, (
      f'returns.shape was {returns.shape} and '
      f'old_log_probs.shape was {old_log_probs.shape}')

  probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
  assert probs_ratio.shape == old_log_probs.shape, (
      f'probs_ratio.shape was {probs_ratio.shape} and '
      f'old_log_probs.shape was {old_log_probs.shape}')

  # jaxified versions of
  # returns[dones] = rewards[dones]
  # values[dones] = 0
  returns = jnp.where(dones, rewards, returns)
  values = jnp.where(dones, jnp.zeros_like(values), values)
  advantages = returns - values
  if normalize_advantages:
    advantages = advantages - jnp.mean(advantages)
    advantages /= jnp.std(advantages) + 1e-8
  assert old_log_probs.shape == advantages.shape, (
      f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape '
      f'was {advantages.shape}')

  unclipped_objective = UnclippedObjective(probs_ratio, advantages)
  assert unclipped_objective.shape == advantages.shape, (
      f'unclipped_objective.shape was {unclipped_objective.shape} and '
      f'advantages.shape was {advantages.shape}')

  clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
  assert clipped_objective.shape == advantages.shape, (
      f'clipped_objective.shape was {clipped_objective.shape} and '
      f'advantages.shape was {advantages.shape}')

  ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
  assert ppo_objective.shape == advantages.shape, (
      f'ppo_objective.shape was {ppo_objective.shape} and '
      f'advantages.shape was {advantages.shape}')

  return ppo_objective
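# A self-contained sketch of the clipped-minimum core of PPOObjective,
# with made-up numbers (the library's ProbsRatio / UnclippedObjective /
# ClippedObjective helpers are inlined here):
import jax.numpy as jnp

epsilon = 0.2
probs_ratio = jnp.array([0.7, 1.0, 1.5])  # exp(new_log_prob - old_log_prob)
advantages = jnp.array([1.0, -1.0, 2.0])
unclipped = probs_ratio * advantages
clipped = jnp.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages
print(jnp.minimum(unclipped, clipped))  # the pessimistic (smaller) term wins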
def ClipFraction(dist_inputs, actions, old_log_probs):
  """Fraction of probability ratios clipped by epsilon (PPO diagnostic)."""
  probs_ratio = rl_layers.ProbsRatio(
      dist_inputs, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob)
  return jnp.mean(jnp.abs(probs_ratio - 1) > self._epsilon)
def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
  """Mean of the clipped objective from the PPO algorithm."""
  ppo_objective = rl_layers.PPOObjective(
      dist_inputs, values, returns, dones, rewards, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob,
      epsilon=self._epsilon,
      normalize_advantages=self._normalize_advantages)
  return jnp.mean(ppo_objective)
def f(model_output, targets):  # pylint: disable=invalid-name
  beta2 = beta ** 2
  predictions = jnp.argmax(model_output, axis=-1)
  n_categories = model_output.shape[-1]
  f_scores = jnp.empty(0)
  for k in range(initial_category_index, n_categories):
    _, _, _, precision, recall = _precision_recall(predictions, targets, k)
    f_scores = jnp.append(f_scores, _f_score(precision, recall, beta2))
  return jnp.mean(f_scores)
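# A worked example of the F-beta score averaged above (made-up precision
# and recall; `_precision_recall` and `_f_score` are helpers elsewhere in
# the library, so the formula is re-derived inline):
#   F_beta = (1 + beta^2) * P * R / (beta^2 * P + R)
beta2 = 1.0  # beta = 1 recovers the ordinary F1 score
precision, recall = 0.8, 0.5
f_score = (beta2 + 1) * (precision * recall) / (beta2 * precision + recall)
print(f_score)  # ~0.615 for these numbers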
def ApproximateKLDivergence(dist_inputs, actions, old_log_probs, log_prob_fun):
  """Approximate KL divergence between the new and old policies."""
  new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
  assert new_log_probs.shape == old_log_probs.shape, (
      f'new_log_probs.shape was {new_log_probs.shape} and '
      f'old_log_probs.shape was {old_log_probs.shape}')
  # Note: ** binds tighter than *, so this is 0.5 * (mean(diff)) ** 2.
  approximate_kl_divergence = (
      0.5 * jnp.mean(new_log_probs - old_log_probs) ** 2)
  return approximate_kl_divergence
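# A sketch of the estimator with made-up log-probs: as written above, the
# value is 0.5 * (mean log-ratio)^2, a cheap proxy for KL that is zero
# when the new and old policies agree on average.
import jax.numpy as jnp

new_log_probs = jnp.array([-1.0, -0.5, -2.0])
old_log_probs = jnp.array([-1.2, -0.4, -1.8])
approx_kl = 0.5 * jnp.mean(new_log_probs - old_log_probs) ** 2
print(approx_kl)  # small when the two policies are close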
def _aggregate_values(self, values, aggregate_max, act_log_probs):
  if self._q_value:
    if aggregate_max:
      values = jnp.max(values, axis=1)
    elif self._sample_all_discrete_actions:
      values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
    else:
      values = jnp.mean(values, axis=1)
  return np.array(values)  # Move the values to CPU.
def calculate_weights(self, advantages):
  """Calculates advantage-based weights for log loss in policy training."""
  if self._advantage_normalization:
    # Normalize advantages.
    advantages -= jnp.mean(advantages)
    advantage_std = jnp.std(advantages)
    advantages /= advantage_std + self._advantage_normalization_epsilon
  weights = self._weight_fn(advantages)
  assert weights.shape == advantages.shape
  return weights
def predict(x, weights, state, rng):
  """Predict function JIT-compiled and parallelized as requested."""
  res, state = _combine_devices(model_predict(
      reshape_by_device(x, n_devices),
      weights,
      state,
      jnp.stack(fastmath.random.split(rng, n_devices))))
  if do_mean:
    return fastmath.nested_map(lambda y: jnp.mean(y, axis=0), res), state
  else:
    return res, state
def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs,
      mask):
  """A2C objective mean."""
  # TODO(henrykm): include dones, rewards
  del old_log_probs
  a2c_objective = rl_layers.A2CObjective(
      dist_inputs, values, returns, dones, rewards, actions, mask,
      log_prob_fun=self._policy_dist.log_prob,
      normalize_advantages=self._normalize_advantages)
  return jnp.mean(a2c_objective)
def f(values, weights):  # pylint: disable=invalid-name
  # This function assumes weights are 0 or 1.
  # Then compute 1: not-correct, 0: correct or masked
  not_correct = (1.0 - values) * weights
  axis_to_sum = list(range(1, len(not_correct.shape)))
  # Summing not-correct on all axes but batch. We're summing 0s and 1s,
  # so the sum is 0 if it's all 0 and >=1 in all other cases.
  not_correct_seq = jnp.sum(not_correct, axis=axis_to_sum)
  # Sequence is correct if not_correct_seq is 0, reverting here.
  correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq)
  return jnp.mean(correct_seq)  # Mean over batch.
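# A small worked example of this whole-sequence accuracy (made-up values
# and weights): a sequence counts as correct only if every unmasked
# position is correct.
import jax.numpy as jnp

values = jnp.array([[1., 1., 1.],    # all positions correct
                    [1., 0., 1.]])   # one position wrong
weights = jnp.ones_like(values)      # nothing masked
not_correct_seq = jnp.sum((1.0 - values) * weights, axis=1)  # [0., 1.]
correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq)        # [1., 0.]
print(jnp.mean(correct_seq))  # 0.5 -> half of the sequences fully correct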
def mean_or_pmean(n_devices, x, axis=None):
  """Computes the mean of a distributed value ``x``.

  Args:
    n_devices: Number of devices.
    x: Distributed array.
    axis: Axis along which to compute means; can only be ``0`` or ``None``.

  Returns:
    A local array.
  """
  if fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1:
    if axis not in (None, 0):
      raise ValueError('axis can only be None or 0')
    x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices
    if axis is None:
      x = jnp.mean(x)
    return x
  else:
    return jnp.mean(x, axis=axis)
def fn(dist_inputs, actions, q_values, act_log_probs, mask):
  del dist_inputs, actions, mask
  q_values = jnp.swapaxes(q_values, 0, 1)
  act_log_probs = jnp.swapaxes(act_log_probs, 0, 1)
  if self._sample_all_discrete_actions:
    values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
  else:
    values = jnp.mean(q_values, axis=0)
  advantages = q_values - values  # Broadcasting values over n_samples.
  if preprocess:
    advantages = self._preprocess_advantages(advantages)
  return advantages
def A2CObjective(dist_inputs, values, returns, dones, rewards,
                 actions, mask, log_prob_fun, normalize_advantages):
  """Definition of the Advantage Actor Critic (A2C) loss."""
  # dist_inputs of the shape float32[128,1,18]
  # values of the shape float32[128,1,1]
  # returns of the shape float32[128,1,1]
  # dones of the shape int32[128,1,1]
  # actions of the shape int32[128,1]
  # and mask of the shape float32[128,1]
  # We have to squeeze values and returns, because we
  # are planning to compute (return - values) * new_log_probs * mask
  # and all of them should be of the same dimension.
  values = values.squeeze(axis=2)
  returns = returns.squeeze(axis=2)
  dones = dones.squeeze(axis=2)
  rewards = rewards.squeeze(axis=2)
  assert rewards.shape == dones.shape, (
      f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
  assert dones.shape == values.shape, (
      f'dones.shape was {dones.shape} and values.shape was {values.shape}')
  assert returns.shape == values.shape, (
      f'returns.shape was {returns.shape} and values.shape was {values.shape}')
  assert values.shape == mask.shape, (
      f'values.shape was {values.shape} and mask.shape was {mask.shape}')
  assert returns.shape[0] == dist_inputs.shape[0], (
      f'returns.shape[0] was {returns.shape[0]} and dist_inputs.shape[0] was '
      f'{dist_inputs.shape[0]}')

  new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
  assert new_log_probs.shape == mask.shape, (
      f'new_log_probs.shape was {new_log_probs.shape} and mask.shape was '
      f'{mask.shape}')

  # jaxified versions of
  # returns[dones] = rewards[dones]
  # values[dones] = 0
  returns = jnp.where(dones, rewards, returns)
  values = jnp.where(dones, jnp.zeros_like(values), values)
  advantages = returns - values
  if normalize_advantages:
    advantages = advantages - jnp.mean(advantages)
    advantages /= jnp.std(advantages) + 1e-8
  assert new_log_probs.shape == advantages.shape, (
      f'new_log_probs.shape was {new_log_probs.shape} and advantages.shape '
      f'was {advantages.shape}')

  # Part of the motivation for the squeezes and assertions is to avoid
  # [128,1] * [128,1,1] * [128] multiplications in the definition
  # of the a2c objective - we insist on the same shapes.
  a2c_objective = -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask)
  return a2c_objective
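# A compact sketch of the masked objective at the end of A2CObjective,
# with made-up tensors: -sum(log_probs * advantages * mask) / sum(mask).
import jax.numpy as jnp

new_log_probs = jnp.array([[-0.5], [-1.0]])
advantages = jnp.array([[2.0], [1.0]])
mask = jnp.array([[1.0], [1.0]])
a2c = -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask)
print(a2c)  # 1.0 -- lower when high-advantage actions get high log-prob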
def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
  """Clipped objective from the PPO algorithm."""
  del dones, rewards
  advantages = returns - values
  probs_ratio = rl_layers.ProbsRatio(
      dist_inputs, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob)
  # advantages are of the shape [128,1,1]
  # and probs_ratio are of the shape [128,1]
  advantages = advantages.squeeze(axis=2)
  clipped_objective = rl_layers.ClippedObjective(
      probs_ratio, advantages, epsilon=self._epsilon)
  return jnp.mean(clipped_objective)
def _aggregate_values(self, values, aggregate, act_log_probs):
  # Normalize the Q-values before aggregation, so it can adapt to the scale
  # of the returns. This does not affect mean and max aggregation.
  scale = 1
  epsilon = 1e-5
  if self._q_value_normalization == 'std':
    scale = jnp.std(values) + epsilon
  elif self._q_value_normalization == 'abs':
    scale = jnp.mean(jnp.abs(values - jnp.mean(values))) + epsilon
  values /= scale

  temp = self._q_value_temperature
  if self._q_value:
    assert values.shape[:2] == (
        self._value_batch_size, self._q_value_n_samples)
    if aggregate == 'max':
      # max_a Q(s, a)
      values = jnp.max(values, axis=1)
    elif aggregate == 'softmax':
      # sum_a (Q(s, a) * w(s, a))
      # where w(s, .) = softmax (Q(s, .) / T)
      weights = tl.Softmax(axis=1)(values / temp)
      values = jnp.sum(values * weights, axis=1)
    elif aggregate == 'logsumexp':
      # log(mean_a exp(Q(s, a) / T)) * T
      n = values.shape[1]
      values = (fastmath.logsumexp(values / temp, axis=1) - jnp.log(n)) * temp
    else:
      assert aggregate == 'mean'
      # mean_a Q(s, a)
      if self._sample_all_discrete_actions:
        values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
      else:
        values = jnp.mean(values, axis=1)

  # Re-scale the Q-values after aggregation.
  values *= scale
  return np.array(values)  # Move the values to CPU.
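# A sketch of the 'logsumexp' aggregation above with made-up Q-values:
# T * (logsumexp(Q / T) - log(n)) interpolates between the mean of the
# samples (large T) and their max (small T).
import jax.numpy as jnp
from jax.scipy.special import logsumexp

q = jnp.array([1.0, 2.0, 4.0])
n = q.shape[0]
for temp in (100.0, 1.0, 0.01):
  agg = (logsumexp(q / temp) - jnp.log(n)) * temp
  print(temp, agg)  # ~mean(q) = 2.33 for large T, ~max(q) = 4.0 for small T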
def f(new_log_probs, advantages, old_log_probs, mask):
  # new_log_probs of the shape float32[128,1]
  # advantages of the shape int32[128,1]
  # old_log_probs of the shape int32[128,1]
  # mask of the shape int32[128,1]
  if new_log_probs.shape != advantages.shape:
    raise ValueError('New log-probs and advantages shapes '
                     'should be the same, %s != %s' %
                     (new_log_probs.shape, advantages.shape))
  if new_log_probs.shape != old_log_probs.shape:
    raise ValueError('New log-probs and old log-probs shapes '
                     'should be the same, %s != %s' %
                     (new_log_probs.shape, old_log_probs.shape))
  if new_log_probs.shape != mask.shape:
    raise ValueError('New log-probs and mask shapes should be the same'
                     ', %s != %s' % (new_log_probs.shape, mask.shape))

  # The ratio between new_probs and old_probs expressed
  # using log_probs and exponentiation.
  probs_ratio = jnp.exp(new_log_probs - old_log_probs)
  if advantages.shape != probs_ratio.shape:
    raise ValueError('Advantages and probs-ratio shapes '
                     'should be the same, %s != %s' %
                     (advantages.shape, probs_ratio.shape))
  unclipped_objective = probs_ratio * advantages
  clipped_objective = jnp.clip(
      probs_ratio, 1 - self._epsilon, 1 + self._epsilon) * advantages
  if unclipped_objective.shape != clipped_objective.shape:
    raise ValueError('unclipped_objective and clipped_objective shapes '
                     'should be the same, %s != %s' %
                     (unclipped_objective.shape, clipped_objective.shape))
  ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
  if ppo_objective.shape != mask.shape:
    raise ValueError('ppo_objective and mask shapes '
                     'should be the same, %s != %s' %
                     (ppo_objective.shape, mask.shape))

  ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask)
  entropy_vec = self._policy_dist.entropy(new_log_probs) * self._entropy_coeff
  entropy_loss = jnp.mean(entropy_vec)
  combined_loss = ppo_loss - entropy_loss
  return combined_loss
def forward(self, inputs):
  gamma, beta, epsilon_l = self.weights

  epsilon = self._init_epsilon
  if epsilon_l is not base.EMPTY_WEIGHTS:
    epsilon += jnp.abs(epsilon_l[0])

  # Omit B and C.
  axis = tuple(range(1, len(jnp.shape(inputs)) - 1))
  # (B, 1, 1, C)
  nu2 = jnp.mean(inputs**2, axis=axis, keepdims=True)
  # (B, W, H, C)
  xhat = inputs / jnp.sqrt(nu2 + epsilon)

  return gamma * xhat + beta
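# A sketch of the normalization above on a made-up NHWC batch: nu^2 is the
# mean of squared activations over the spatial axes only, so each
# (batch, channel) pair is rescaled independently. (This follows the
# Filter Response Normalization recipe; the numbers are illustrative.)
import jax.numpy as jnp

inputs = jnp.ones((2, 4, 4, 3))          # (B, W, H, C)
axis = tuple(range(1, inputs.ndim - 1))  # (1, 2) -> the spatial axes
nu2 = jnp.mean(inputs**2, axis=axis, keepdims=True)  # (B, 1, 1, C)
xhat = inputs / jnp.sqrt(nu2 + 1e-6)
print(nu2.shape, xhat.shape)  # (2, 1, 1, 3) (2, 4, 4, 3)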
def policy_metrics(self):
  metrics = {
      'policy_loss': self.policy_loss,
      'advantage_mean': tl.Serial(
          self._policy_inputs_to_advantages(False),
          tl.Fn('Mean', lambda x: jnp.mean(x))  # pylint: disable=unnecessary-lambda
      ),
      'advantage_std': tl.Serial(
          self._policy_inputs_to_advantages(False),
          tl.Fn('Std', lambda x: jnp.std(x))  # pylint: disable=unnecessary-lambda
      ),
  }
  metrics.update(awr_metrics(
      self._beta,
      preprocess_layer=self._policy_inputs_to_advantages(True)))
  return metrics
def Mean(axis=-1, keepdims=False):
  """Returns a layer that computes mean values using one tensor axis.

  `Mean` uses one tensor axis to form groups of values and replaces each
  group with the mean value of that group. The resulting values can either
  remain in their own size 1 axis (`keepdims=True`), or that axis can be
  removed from the overall tensor (default `keepdims=False`), lowering the
  rank of the tensor by one.

  Args:
    axis: Axis along which values are grouped for computing a mean.
    keepdims: If `True`, keep the resulting size 1 axis as a separate tensor
        axis; else, remove that axis.
  """
  return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))
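# A hedged usage sketch of the Mean layer (assuming the conventional
# `from trax import layers as tl` alias; trax layers are callable directly
# on arrays):
import jax.numpy as jnp
from trax import layers as tl

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])
print(tl.Mean(axis=-1)(x))                       # [2., 5.]
print(tl.Mean(axis=-1, keepdims=True)(x).shape)  # (2, 1)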
def f(model_output, targets):  # pylint: disable=invalid-name

  def non_nan(x):  # pylint: disable=invalid-name
    return jnp.where(jnp.isnan(x), 0., x)

  beta2 = beta**2
  predictions = jnp.argmax(model_output, axis=-1)
  n_categories = model_output.shape[-1]
  f_scores = jnp.empty(0)
  for k in range(initial_category_index, n_categories):
    n_correct = sum((predictions == k) & (targets == k))
    precision = non_nan(n_correct / sum(predictions == k))
    recall = non_nan(n_correct / sum(targets == k))
    f_score = non_nan((beta2 + 1) * (precision * recall) /
                      ((beta2 * precision) + recall))
    f_scores = jnp.append(f_scores, f_score)
  return jnp.mean(f_scores)
def train_step(self, batch):
  """Run one training step and update self._opt_state."""
  # Calculate the current optimizer parameters.
  opt_param_updates = self._for_n_devices(
      {'learning_rate': np.array(self.learning_rate)})
  opt_state = self._opt_state
  opt_state.opt_params.update(opt_param_updates)

  # Run the update.
  weights, slots, opt_params = opt_state
  (weights, slots), stat, self._model_state, self._rngs = self._jit_update_fn(
      (weights, slots), self._step, opt_params, batch,
      self._model_state, self._rngs)
  self._opt_state = opt_state._replace(weights=weights, slots=slots)

  if self._should_log_now():
    for name, value in stat.items():
      scalar_value = np.mean(value)  # On multiple devices, take the mean.
      self._train_sw.scalar('training/' + name, scalar_value, step=self._step)
  self._step += 1