def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, sigma)."""
  a = mu.shape[-1] * jnp.log(2 * jnp.pi)
  _, b = jnp.linalg.slogdet(sigma)
  y = jnp.linalg.solve(sigma, x - mu)
  y = jnp.expand_dims(y, axis=-1)
  xm = jnp.expand_dims(x - mu, axis=-2)
  c = jnp.matmul(xm, y)
  c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
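# A quick sanity check for log_gaussian_pdf (a sketch, not part of the
# library; assumes `jnp` is bound to jax.numpy as in the function above).
# At the mean of a standard d-dimensional Gaussian, the quadratic term and
# the log-determinant both vanish, so log N(mu | mu, I) = -0.5 * d * log(2*pi).
import jax.numpy as jnp

x = jnp.zeros(2)
mu = jnp.zeros(2)
sigma = jnp.eye(2)
assert jnp.allclose(log_gaussian_pdf(x, mu, sigma), -jnp.log(2 * jnp.pi))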
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, diag(diag_sigma))."""
  a = mu.shape[-1] * jnp.log(2 * jnp.pi)
  b = jnp.sum(jnp.log(diag_sigma), axis=-1)
  # Parenthesize (x - mu) before dividing; without the parentheses, operator
  # precedence computes x - (mu / diag_sigma), which is wrong.
  y = (x - mu) / diag_sigma
  y = jnp.expand_dims(y, axis=-1)
  xm = jnp.expand_dims(x - mu, axis=-2)
  c = jnp.matmul(xm, y)
  c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
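# With the precedence fix above, the diagonal version agrees with the full
# version when sigma = diag(diag_sigma) (a sanity sketch; values arbitrary).
x = jnp.array([0.5, -1.0])
mu = jnp.zeros(2)
diag_sigma = jnp.array([2.0, 0.5])
full = log_gaussian_pdf(x, mu, jnp.diag(diag_sigma))
diag = log_gaussian_diag_pdf(x, mu, diag_sigma)
assert jnp.allclose(full, diag)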
def update(self, step, grads, weights, slots, opt_params):
  updates = []
  learning_rate = opt_params['learning_rate']
  beta1 = opt_params['beta1']
  decay_rate = opt_params['decay_rate']
  clipping_threshold = opt_params['clipping_threshold']
  weight_decay_rate = opt_params['weight_decay_rate']
  epsilon1 = opt_params['epsilon1']
  epsilon2 = opt_params['epsilon2']
  decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
  update_scale = learning_rate
  if self._multiply_by_parameter_scale:
    update_scale *= np.maximum(np.sqrt(np.mean(weights * weights)), epsilon2)
  mixing_rate = 1.0 - decay_rate

  grads_sqr = grads * grads + epsilon1
  if self._factored and len(weights.shape) >= 2:
    # Factored second-moment estimate: keep only per-row and per-column
    # running means of the squared gradients instead of the full matrix.
    v_row = slots.pop(0)
    v_col = slots.pop(0)
    new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr, axis=-1)
    new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr, axis=-2)
    updates.extend([new_v_row, new_v_col])
    row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (new_v_row / row_col_mean)**-0.5
    col_factor = (new_v_col)**-0.5
    y = (grads * np.expand_dims(row_factor, axis=-1) *
         np.expand_dims(col_factor, axis=-2))
  else:
    # Unfactored case: a single full second-moment accumulator, as in Adam.
    v = slots.pop(0)
    new_v = decay_rate * v + mixing_rate * grads_sqr
    updates.append(new_v)
    y = grads * (new_v)**-0.5

  if self._do_clipping:
    clipping_denom = (
        np.maximum(1.0, np.sqrt(np.mean(y * y)) / clipping_threshold))
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._do_momentum:
    m = slots.pop(0)
    new_m = beta1 * m + (1.0 - beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_weights = (1 - weight_decay_rate) * weights - subtrahend
  # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
  return new_weights.astype(weights.dtype), updates
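# A minimal numeric sketch of the factored second-moment trick used above
# (illustrative only, not part of the optimizer): for a weight matrix of
# shape (m, n), Adafactor stores per-row and per-column means of the squared
# gradients, O(m + n) numbers, and reconstructs a full (m, n) scale as
# outer(v_row, v_col) / mean(v_row), which preserves the row and column means.
import numpy as np

grads_sqr = np.random.rand(3, 4) ** 2
v_row = grads_sqr.mean(axis=-1)                    # shape (3,)
v_col = grads_sqr.mean(axis=-2)                    # shape (4,)
approx_v = np.outer(v_row, v_col) / v_row.mean()   # shape (3, 4)
assert np.allclose(approx_v.mean(axis=-1), v_row)
assert np.allclose(approx_v.mean(axis=-2), v_col)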
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  if self._mode in ('train', 'eval'):
    x = inputs
    symbol_size = np.shape(x)[1]
    px = weights[:, :symbol_size, :]
    if self._dropout == 0:
      return (x + px, state)
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if math.backend_name() == 'jax':
        keep_prob = jax.lax.tie_in(
            x, np.full((), keep_prob, dtype=x.dtype))
      keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return (x + px * multiplier, state)
  else:
    assert self._mode == 'predict'
    assert self._dropout == 0
    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    return (inputs + np.expand_dims(weights[0, state, :], 1), state + 1)
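# A small sketch of the broadcast-dropout bookkeeping above (shapes are
# illustrative): setting an entry of noise_shape to 1 samples one Bernoulli
# mask that is then broadcast along that axis, so with a broadcast dim of -2
# every sequence position shares the same per-feature keep/drop decision.
px_shape = [1, 5, 8]               # (1, seq_len, d_feature)
noise_shape = list(px_shape)
for dim in (-2,):                  # broadcast over the sequence axis
  noise_shape[dim] = 1
assert noise_shape == [1, 1, 8]    # one mask row, broadcast against (1, 5, 8)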
def combined_loss(new_weights,
                  observations,
                  actions,
                  target_values,
                  advantage_weights,
                  policy_and_value_net_apply,
                  state=None,
                  rng=None):
  """Returns the loss components."""
  # Reshape as (batch, 1, *obs_shape), since that is the signature demanded
  # by `policy_and_value_net_apply`.
  observations = np.expand_dims(observations, axis=1)
  (log_probab_actions_new, value_predictions_new) = (
      policy_and_value_net_apply(
          observations, weights=new_weights, state=state, rng=rng))
  critic_loss_val, intermediate_state = critic_loss(
      observations, target_values, value_predictions_new, state=state)
  actor_loss_val, final_state = actor_loss(
      actions, advantage_weights, log_probab_actions_new,
      state=intermediate_state)
  entropy_val = entropy(log_probab_actions_new)
  return critic_loss_val, actor_loss_val, entropy_val, final_state
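# The loss components returned above are typically combined downstream into
# one scalar objective (a sketch; the coefficient names are illustrative,
# not taken from this code):
#   total_loss = critic_loss_coef * critic_loss_val
#              + actor_loss_val
#              - entropy_coef * entropy_val
# with the entropy term subtracted to encourage exploration.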
def forward_with_state(self, inputs, weights, state, rng):
  if self._mode != 'predict':
    x = inputs
    symbol_size = jnp.shape(x)[1]
    px = weights[:, :symbol_size, :]
    if self._dropout == 0:
      return (x + px, state)
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if math.backend_name() == 'jax':
        keep_prob = jax.lax.tie_in(
            x, jnp.full((), keep_prob, dtype=x.dtype))
      keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return (x + px * multiplier, state)
  else:
    if self._dropout != 0:
      raise ValueError(f'In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')
    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    if inputs.shape[1] == 1:
      return (inputs + jnp.expand_dims(weights[0, state, :], 1), state + 1)
    else:
      emb = []
      for i in range(inputs.shape[0]):
        emb.append(jax.lax.dynamic_slice_in_dim(
            weights[0], state[i], inputs.shape[1], axis=0))
      return inputs + jnp.stack(emb, 0), state + inputs.shape[1]
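# A minimal sketch of the predict-mode slicing above (values illustrative):
# for each batch element, dynamic_slice_in_dim takes a window of length
# seq_len from the positional table, starting at that element's position.
import jax

table = jnp.arange(10.0).reshape(10, 1)   # (max_len, d_feature) with d = 1
window = jax.lax.dynamic_slice_in_dim(table, 3, 4, axis=0)
assert jnp.allclose(window[:, 0], jnp.array([3.0, 4.0, 5.0, 6.0]))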
def combined_loss(new_weights,
                  observations,
                  actions,
                  target_values,
                  advantage_weights,
                  policy_and_value_net_apply,
                  action_space,
                  state=None,
                  rng=None):
  """Returns the loss components."""
  # TODO(afrozm): This is where we will eventually need to feed in the
  # observations that precede this one. Currently the replay buffer gives
  # the observation as is; for transformer-like policies, we should also get
  # all the earlier observations, and then the extra dimension will just be
  # time. For now we reshape as (batch, 1, *obs_shape).
  observations = jnp.expand_dims(observations, axis=1)
  (log_probab_actions_new, value_predictions_new,
   state_new, unused_rng_new) = (
       policy_based_utils.run_policy_all_timesteps(
           policy_and_value_net_apply,
           observations,
           new_weights,
           state,
           rng,
           action_space,
       ))
  critic_loss_val, intermediate_state = critic_loss(
      observations, target_values, value_predictions_new, state=state_new)
  actor_loss_val, final_state = actor_loss(
      actions, advantage_weights, log_probab_actions_new,
      state=intermediate_state)
  entropy_val = entropy(log_probab_actions_new)
  return critic_loss_val, actor_loss_val, entropy_val, final_state
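# A reading aid for the state threading above (no new behavior): the layer
# state returned by the policy application (state_new) feeds critic_loss,
# whose intermediate_state feeds actor_loss, whose final_state is returned.
# This keeps stateful layers (e.g. fast-inference caches) consistent across
# the three sub-computations.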