def update(self, step, grads, params, slots, opt_params):
  updates = []
  learning_rate = opt_params["learning_rate"]
  beta1 = opt_params["beta1"]
  decay_rate = opt_params["decay_rate"]
  clipping_threshold = opt_params["clipping_threshold"]
  weight_decay_rate = opt_params["weight_decay_rate"]
  epsilon1 = opt_params["epsilon1"]
  epsilon2 = opt_params["epsilon2"]
  decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
  update_scale = learning_rate
  if self._multiply_by_parameter_scale:
    # Scale updates by the RMS of the parameters, floored at epsilon2.
    update_scale *= np.maximum(np.sqrt(np.mean(params * params)), epsilon2)
  mixing_rate = 1.0 - decay_rate

  grads_sqr = grads * grads + epsilon1
  if self._factored and len(params.shape) >= 2:
    # Factored second-moment estimate: store only row and column means of
    # the squared gradients instead of the full accumulator.
    v_row = slots.pop(0)
    v_col = slots.pop(0)
    new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr, axis=-1)
    new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr, axis=-2)
    updates.extend([new_v_row, new_v_col])
    row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (new_v_row / row_col_mean)**-0.5
    col_factor = (new_v_col)**-0.5
    y = (grads * np.expand_dims(row_factor, axis=-1) *
         np.expand_dims(col_factor, axis=-2))
  else:
    # Unfactored case: a single full second-moment accumulator.
    v = slots.pop(0)
    new_v = decay_rate * v + mixing_rate * grads_sqr
    updates.append(new_v)
    y = grads * (new_v)**-0.5

  if self._do_clipping:
    # Clip the update so its RMS does not exceed clipping_threshold.
    clipping_denom = np.maximum(
        1.0, np.sqrt(np.mean(y * y)) / clipping_threshold)
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._do_momentum:
    m = slots.pop(0)
    new_m = beta1 * m + (1.0 - beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_params = (1 - weight_decay_rate) * params - subtrahend
  # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
  return new_params.astype(params.dtype), updates
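A minimal standalone numpy sketch (not part of the optimizer above) of the factored second-moment trick used in update: for a 2-D parameter, the full accumulator V is approximated by a rank-1 reconstruction from its row and column means, so only O(rows + cols) state is stored instead of O(rows * cols). The onp alias, the shapes, and the single-step setup (no decay averaging) are assumptions for illustration.

import numpy as onp

grads = onp.random.randn(4, 3)
grads_sqr = grads * grads + 1e-30       # epsilon1 keeps the sqrt stable

v_row = onp.mean(grads_sqr, axis=-1)    # shape (4,): row statistics
v_col = onp.mean(grads_sqr, axis=-2)    # shape (3,): column statistics

# Rank-1 reconstruction of the full second-moment matrix.
v_approx = onp.outer(v_row, v_col) / onp.mean(v_row)

# The preconditioned update computed above is equivalent to
# grads / sqrt(v_approx):
row_factor = (v_row / onp.mean(v_row)) ** -0.5
col_factor = v_col ** -0.5
y = grads * row_factor[:, None] * col_factor[None, :]
assert onp.allclose(y, grads / onp.sqrt(v_approx))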
def predict(x, weights, state, rng):
  """Predict function, jitted and parallelized as requested."""
  res, state = backend.combine_devices(model_predict(
      backend.reshape_by_device(x, n_devices),
      weights,
      state,
      np.stack(jax_random.split(rng, n_devices))))
  return layers.nested_map(lambda y: np.mean(y, axis=0), res), state
def forward(self, inputs, weights):
  gamma, beta, epsilon_l = weights

  epsilon = self._init_epsilon
  if epsilon_l is not base.EMPTY_WEIGHTS:
    epsilon += np.abs(epsilon_l[0])

  # Omit B and C.
  axis = tuple(range(1, len(np.shape(inputs)) - 1))
  # (B, 1, 1, C)
  nu2 = np.mean(inputs**2, axis=axis, keepdims=True)
  # (B, W, H, C)
  xhat = inputs / np.sqrt(nu2 + epsilon)

  return gamma * xhat + beta
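For orientation, a standalone numpy sketch of the normalization computed in forward above, a filter-response-style norm: the mean of squares nu2 is reduced over the spatial axes only, so for an NHWC batch axis is (1, 2) and nu2 broadcasts back over H and W. The shapes and the gamma/beta values are illustrative assumptions, not the layer's real weights.

import numpy as onp

inputs = onp.random.randn(2, 8, 8, 16)               # (B, H, W, C)
axis = tuple(range(1, inputs.ndim - 1))              # (1, 2): omit B and C
nu2 = onp.mean(inputs**2, axis=axis, keepdims=True)  # (B, 1, 1, C)
xhat = inputs / onp.sqrt(nu2 + 1e-6)                 # (B, H, W, C)

gamma, beta = onp.ones(16), onp.zeros(16)            # per-channel affine params
out = gamma * xhat + beta
assert out.shape == inputs.shape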
def make_unit_length(self, x, epsilon=1e-6):
  variance = np.mean(x**2, axis=-1, keepdims=True)
  norm_inputs = x / np.sqrt(variance + epsilon)
  return norm_inputs
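A quick numpy check of what make_unit_length produces: dividing by the root mean square over the last axis gives each vector unit RMS, so its L2 norm is sqrt(d) for dimension d rather than 1. The shapes here are arbitrary illustrative choices.

import numpy as onp

d = 16
x = onp.random.randn(5, d) * 7.0
out = x / onp.sqrt(onp.mean(x**2, axis=-1, keepdims=True) + 1e-6)

assert onp.allclose(onp.sqrt(onp.mean(out**2, axis=-1)), 1.0, atol=1e-3)
assert onp.allclose(onp.linalg.norm(out, axis=-1), onp.sqrt(d), atol=1e-2)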
def Mean(x, axis=-1, keepdims=False, **unused_kwargs):
  return np.mean(x, axis=axis, keepdims=keepdims)
def _fast_mean_and_variance(self, x):
  mean = np.mean(x, self._axis, keepdims=True)
  # Fast but less numerically-stable variance calculation than np.var.
  m1 = np.mean(x**2, self._axis, keepdims=True)
  variance = m1 - mean**2
  return mean, variance
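The comment above notes this is less numerically stable than np.var; a small float32 numpy demonstration of why: computing E[x^2] - E[x]^2 in one pass suffers catastrophic cancellation when the mean is large relative to the spread, whereas np.var subtracts the mean first. The values here are illustrative only.

import numpy as onp

onp.random.seed(0)
x = (1e4 + onp.random.randn(1000)).astype(onp.float32)
fast = onp.mean(x**2) - onp.mean(x)**2  # one pass; ~7 significant digits cancel
stable = onp.var(x)                     # subtracts the mean first
print(fast, stable)                     # fast can be wildly off, or even negative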
def LayerNorm(x, weights, epsilon=1e-6, **unused_kwargs):  # pylint: disable=invalid-name
  (scale, bias) = weights
  mean = np.mean(x, axis=-1, keepdims=True)
  variance = np.mean((x - mean)**2, axis=-1, keepdims=True)
  norm_inputs = (x - mean) / np.sqrt(variance + epsilon)
  return norm_inputs * scale + bias
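A quick sanity check of LayerNorm above (assuming np behaves like standard numpy here): with scale 1 and bias 0, each vector of the output should have roughly zero mean and unit variance over the last axis. The input values are arbitrary.

import numpy as onp

x = onp.random.randn(4, 10) * 3.0 + 5.0
weights = (onp.ones(10), onp.zeros(10))  # (scale, bias)
out = LayerNorm(x, weights)

assert onp.allclose(onp.mean(out, axis=-1), 0.0, atol=1e-5)
assert onp.allclose(onp.var(out, axis=-1), 1.0, atol=1e-3)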