Example #1
    def update(self, step, grads, params, slots, opt_params):
        updates = []
        learning_rate = opt_params["learning_rate"]
        beta1 = opt_params["beta1"]
        decay_rate = opt_params["decay_rate"]
        clipping_threshold = opt_params["clipping_threshold"]
        weight_decay_rate = opt_params["weight_decay_rate"]
        epsilon1 = opt_params["epsilon1"]
        epsilon2 = opt_params["epsilon2"]
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            update_scale *= np.maximum(np.sqrt(np.mean(params * params)),
                                       epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads + epsilon1
        if self._factored and len(params.shape) >= 2:
            v_row = slots.pop(0)
            v_col = slots.pop(0)
            new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-1)
            new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-2)
            updates.extend([new_v_row, new_v_col])
            row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (new_v_row / row_col_mean)**-0.5
            col_factor = (new_v_col)**-0.5
            y = (grads * np.expand_dims(row_factor, axis=-1) *
                 np.expand_dims(col_factor, axis=-2))
        else:
            v = slots.pop(0)
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v)**-0.5

        if self._do_clipping:
            clipping_denom = (np.maximum(
                1.0,
                np.sqrt(np.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots.pop(0)
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_params = (1 - weight_decay_rate) * params - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_params.astype(params.dtype), updates
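The factored branch above never materializes the full matrix of squared-gradient statistics; it stores only its row means and column means in the slots. Below is a minimal standalone NumPy sketch (not trax code; names such as v_approx are mine, and it assumes a single step with no decay, so the mixing rate is effectively 1) of why multiplying by row_factor and col_factor is equivalent to dividing by a rank-1 reconstruction of that matrix.

import numpy as np

rng = np.random.default_rng(0)
grads = rng.normal(size=(4, 6))
epsilon1 = 1e-30

# Row and column means of the squared gradients: the only statistics
# the factored branch keeps in its slots.
grads_sqr = grads * grads + epsilon1
v_row = np.mean(grads_sqr, axis=-1)          # shape (4,)
v_col = np.mean(grads_sqr, axis=-2)          # shape (6,)

# Rank-1 reconstruction of the full second-moment matrix.
v_approx = np.outer(v_row, v_col) / np.mean(v_row)

# The same preconditioning as in update(), written without building v_approx.
row_col_mean = np.mean(v_row, axis=-1, keepdims=True)
row_factor = (v_row / row_col_mean) ** -0.5
col_factor = v_col ** -0.5
y_factored = grads * row_factor[:, None] * col_factor[None, :]

assert np.allclose(y_factored, grads / np.sqrt(v_approx))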
Example #2
  def predict(x, weights, state, rng):
    """Predict function jited and parallelized as requested."""
    res, state = backend.combine_devices(model_predict(
        backend.reshape_by_device(x, n_devices),
        weights,
        state,
        np.stack(jax_random.split(rng, n_devices))))
    return layers.nested_map(lambda y: np.mean(y, axis=0), res), state
Example #3
    def forward(self, inputs, weights):
        gamma, beta, epsilon_l = weights

        epsilon = self._init_epsilon
        if epsilon_l is not base.EMPTY_WEIGHTS:
            epsilon += np.abs(epsilon_l[0])

        # Omit B and C
        axis = tuple(range(1, len(np.shape(inputs)) - 1))
        # (B, 1, 1, C)
        nu2 = np.mean(inputs**2, axis=axis, keepdims=True)
        # (B, W, H, C)
        xhat = inputs / np.sqrt(nu2 + epsilon)

        return gamma * xhat + beta
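A quick standalone NumPy check of the normalization above (the (B, W, H, C) sizes, the fixed epsilon, and the identity gamma/beta are assumptions made for illustration): dividing by sqrt(nu2 + epsilon) scales every (batch, channel) slice to roughly unit mean square over its spatial positions.

import numpy as np

B, W, H, C = 2, 5, 5, 3
inputs = np.random.default_rng(1).normal(size=(B, W, H, C))
gamma = np.ones((1, 1, 1, C))
beta = np.zeros((1, 1, 1, C))
epsilon = 1e-6

axis = tuple(range(1, inputs.ndim - 1))               # spatial axes (1, 2)
nu2 = np.mean(inputs**2, axis=axis, keepdims=True)    # shape (B, 1, 1, C)
xhat = inputs / np.sqrt(nu2 + epsilon)                # shape (B, W, H, C)
out = gamma * xhat + beta

print(np.mean(xhat**2, axis=axis).round(6))           # every entry is ~1.0
print(out.shape)                                      # (2, 5, 5, 3)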
Example #4
    def make_unit_length(self, x, epsilon=1e-6):
        variance = np.mean(x**2, axis=-1, keepdims=True)
        norm_inputs = x / np.sqrt(variance + epsilon)
        return norm_inputs
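A short usage check (plain NumPy, with a hypothetical (3, 8) input) of what this helper guarantees: each row of the output has unit mean square along the last axis, so a d-dimensional row ends up with L2 norm close to sqrt(d) rather than exactly 1.

import numpy as np

x = np.random.default_rng(2).normal(size=(3, 8))
epsilon = 1e-6

variance = np.mean(x**2, axis=-1, keepdims=True)
norm_inputs = x / np.sqrt(variance + epsilon)

print(np.mean(norm_inputs**2, axis=-1))      # ~1.0 per row (unit mean square)
print(np.linalg.norm(norm_inputs, axis=-1))  # ~sqrt(8) per row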
Example #5
File: core.py  Project: wangleiphy/trax
def Mean(x, axis=-1, keepdims=False, **unused_kwargs):
  return np.mean(x, axis=axis, keepdims=keepdims)
Example #6
    def _fast_mean_and_variance(self, x):
        mean = np.mean(x, self._axis, keepdims=True)
        # Fast but less numerically-stable variance calculation than np.var.
        m1 = np.mean(x**2, self._axis, keepdims=True)
        variance = m1 - mean**2
        return mean, variance
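The comment about numerical stability refers to the one-pass identity Var[x] = E[x**2] - E[x]**2, which cancels catastrophically when the mean is large relative to the spread. A minimal standalone NumPy sketch (the float32 dtype, the axis, and the constant offset are illustrative assumptions) comparing it with the two-pass np.var:

import numpy as np

x = np.random.default_rng(3).normal(size=(4, 256)).astype(np.float32)
axis = -1

# Two-pass reference: np.var computes E[(x - E[x])**2].
stable_var = np.var(x, axis=axis, keepdims=True)

# One-pass formula from _fast_mean_and_variance: E[x**2] - E[x]**2.
mean = np.mean(x, axis=axis, keepdims=True)
fast_var = np.mean(x**2, axis=axis, keepdims=True) - mean**2

print(np.max(np.abs(fast_var - stable_var)))   # tiny for well-scaled inputs

# With a large constant offset the one-pass formula loses precision
# (catastrophic cancellation), which is the trade-off the comment notes.
shifted = x + np.float32(1e4)
fast_shifted = (np.mean(shifted**2, axis=axis, keepdims=True)
                - np.mean(shifted, axis=axis, keepdims=True)**2)
print(np.max(np.abs(fast_shifted - np.var(shifted, axis=axis, keepdims=True))))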
Example #7
def LayerNorm(x, weights, epsilon=1e-6, **unused_kwargs):  # pylint: disable=invalid-name
    (scale, bias) = weights
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.mean((x - mean)**2, axis=-1, keepdims=True)
    norm_inputs = (x - mean) / np.sqrt(variance + epsilon)
    return norm_inputs * scale + bias
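For completeness, a short usage sketch of the function above (plain NumPy, treating the snippet as an ordinary function; the (2, 5) input shape and the identity scale/bias are assumptions for illustration): after normalization each row has roughly zero mean and unit variance along the last axis, which scale and bias can then rescale and shift.

import numpy as np

x = np.random.default_rng(4).normal(loc=3.0, scale=2.0, size=(2, 5))
scale, bias = np.ones(5), np.zeros(5)

out = LayerNorm(x, (scale, bias))

print(np.mean(out, axis=-1))   # ~0 for each row
print(np.var(out, axis=-1))    # ~1 for each row (up to the epsilon term)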