Пример #1
0
def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
    """Evaluate the log-density of a full-covariance Gaussian at ``x``.

    The value returned is
    ``-0.5 * (d * log(2 * pi) + log|det(sigma)| + r^T sigma^{-1} r)``
    where ``r = x - mu`` and ``d = mu.shape[-1]``.

    Args:
        x: point(s) at which the density is evaluated.
        mu: mean; the length of its last axis is used as the dimension ``d``.
        sigma: covariance matrix, passed to ``slogdet`` and ``solve``.

    Returns:
        The log-density, with the two singleton matrix axes of the quadratic
        form squeezed away.
    """
    norm_term = mu.shape[-1] * np.log(2 * np.pi)
    # slogdet returns (sign, log|det|); only the log-magnitude is needed.
    log_det = np.linalg.slogdet(sigma)[1]
    residual = x - mu
    # sigma^{-1} r is obtained from a linear solve rather than an explicit
    # matrix inverse.
    solved = np.linalg.solve(sigma, residual)
    # (1 x d) @ (d x 1): the quadratic form ends up in a trailing 1x1 matrix.
    quad_form = np.matmul(residual[..., None, :], solved[..., :, None])
    quad_form = np.squeeze(quad_form, axis=(-2, -1))
    return -0.5 * (norm_term + log_det + quad_form)
Пример #2
0
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
    """Compute log N(x | mu, diag(diag_sigma)).

    The value returned is
    ``-0.5 * (d * log(2 * pi) + sum(log(diag_sigma)) + sum(r**2 / diag_sigma))``
    where ``r = x - mu``, ``d = mu.shape[-1]`` and the sums run over the last
    axis.

    Args:
        x: point(s) at which the density is evaluated.
        mu: mean; the length of its last axis is used as the dimension ``d``.
        diag_sigma: diagonal entries (variances) of the covariance matrix.

    Returns:
        The log-density, reduced over the last axis.
    """
    residual = x - mu
    a = mu.shape[-1] * np.log(2 * np.pi)
    # The log-determinant of a diagonal matrix is the sum of the logs of its
    # diagonal entries.
    b = np.sum(np.log(diag_sigma), axis=-1)
    # Quadratic form r^T diag(1 / diag_sigma) r. The residual as a whole is
    # divided by the variances; dividing only ``mu`` gives a wrong value.
    c = np.sum(residual * residual / diag_sigma, axis=-1)
    return -0.5 * (a + b + c)
Пример #3
0
    def update(self, step, grads, params, slots, opt_params):
        """Compute the new value and new optimizer slots for one parameter.

        Maintains a decayed average of the squared gradients (kept as separate
        row and column averages when ``self._factored`` is set and ``params``
        has at least two axes, otherwise per element), scales the gradient by
        the inverse square root of that average, optionally clips the RMS of
        the scaled gradient, optionally applies momentum, and subtracts the
        result from the weight-decayed parameter.

        Args:
            step: training step, forwarded to ``self._decay_rate_pow``.
            grads: gradient array for this parameter.
            params: current parameter array.
            slots: sequence of slot arrays ordered as ``[v_row, v_col]``
                (factored) or ``[v]`` (unfactored), followed by ``m`` when
                momentum is enabled. It is only read; the caller's sequence
                is left unmodified.
            opt_params: mapping with the keys ``learning_rate``, ``beta1``,
                ``decay_rate``, ``clipping_threshold``, ``weight_decay_rate``,
                ``epsilon1`` and ``epsilon2``.

        Returns:
            A pair ``(new_params, updates)``: the updated parameter cast to
            ``params.dtype`` and the list of new slot values in the same
            order as ``slots``.

        Raises:
            IndexError: if ``slots`` holds fewer entries than required.
        """
        learning_rate = opt_params["learning_rate"]
        beta1 = opt_params["beta1"]
        decay_rate = opt_params["decay_rate"]
        clipping_threshold = opt_params["clipping_threshold"]
        weight_decay_rate = opt_params["weight_decay_rate"]
        epsilon1 = opt_params["epsilon1"]
        epsilon2 = opt_params["epsilon2"]

        updates = []
        # Read cursor into ``slots``: entries are consumed by index rather
        # than popped, so the caller's sequence is not emptied.
        next_slot = 0

        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            # Scale the step by the RMS of the parameter, floored at epsilon2.
            update_scale *= np.maximum(np.sqrt(np.mean(params * params)),
                                       epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads + epsilon1
        if self._factored and len(params.shape) >= 2:
            v_row = slots[next_slot]
            v_col = slots[next_slot + 1]
            next_slot += 2
            new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-1)
            new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-2)
            updates.extend([new_v_row, new_v_col])
            # The row statistics are normalised by their mean so that the
            # product of the row and column factors has the right scale.
            row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (new_v_row / row_col_mean)**-0.5
            col_factor = (new_v_col)**-0.5
            y = (grads * np.expand_dims(row_factor, axis=-1) *
                 np.expand_dims(col_factor, axis=-2))
        else:
            v = slots[next_slot]
            next_slot += 1
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v)**-0.5

        if self._do_clipping:
            # Shrink y only when its RMS exceeds the clipping threshold.
            clipping_denom = (np.maximum(
                1.0,
                np.sqrt(np.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots[next_slot]
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_params = (1 - weight_decay_rate) * params - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_params.astype(params.dtype), updates
Пример #4
0
def make_target_mask(target, pad=0):
    """Build the attention mask that hides padding and future positions.

    A position is kept only where the token differs from ``pad`` and where
    ``stax.causal_mask`` for the sequence length also allows it.

    Args:
        target: array of token ids, indexed here as two-dimensional with the
            sequence along the last axis.
        pad: token id treated as padding.

    Returns:
        The combined mask, cast back to the dtype of the padding mask, with an
        extra axis inserted at position 1.
    """
    seq_len = target.shape[-1]
    not_padding = (target != pad)[:, None, :]
    combined = not_padding & stax.causal_mask(seq_len)
    # Cast back in case the AND with the causal mask changed the dtype.
    combined = combined.astype(not_padding.dtype)
    return np.expand_dims(combined, axis=1)
Пример #5
0
    def update(self, step, grads, params, slots, opt_params):
        """Compute the new value and new optimizer slots for one parameter.

        Maintains a decayed average of the squared gradients (kept as separate
        row and column averages when ``self._factored`` is set and ``params``
        has at least two axes, otherwise per element), scales the gradient by
        the inverse square root of that average, optionally clips the RMS of
        the scaled gradient, optionally applies momentum, and subtracts the
        result from the weight-decayed parameter.

        Args:
            step: training step, forwarded to ``self._decay_rate_pow``.
            grads: gradient array for this parameter.
            params: current parameter array.
            slots: sequence of slot arrays ordered as ``[v_row, v_col]``
                (factored) or ``[v]`` (unfactored), followed by ``m`` when
                momentum is enabled. It is only read; the caller's sequence
                is left unmodified.
            opt_params: seven-element sequence unpacked as
                ``(learning_rate, beta1, decay_rate, clipping_threshold,
                weight_decay_rate, epsilon1, epsilon2)``.

        Returns:
            A pair ``(new_params, updates)``: the updated parameter and the
            list of new slot values in the same order as ``slots``.

        Raises:
            IndexError: if ``slots`` holds fewer entries than required.
        """
        (learning_rate, beta1, decay_rate, clipping_threshold,
         weight_decay_rate, epsilon1, epsilon2) = opt_params

        updates = []
        # Read cursor into ``slots``: entries are consumed by index rather
        # than popped, so the caller's sequence is not emptied.
        next_slot = 0

        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            # Scale the step by the RMS of the parameter, floored at epsilon2.
            update_scale *= np.maximum(np.sqrt(np.mean(params * params)),
                                       epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads + epsilon1
        if self._factored and len(params.shape) >= 2:
            v_row = slots[next_slot]
            v_col = slots[next_slot + 1]
            next_slot += 2
            new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-1)
            new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-2)
            updates.extend([new_v_row, new_v_col])
            # The row statistics are normalised by their mean so that the
            # product of the row and column factors has the right scale.
            row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (new_v_row / row_col_mean)**-0.5
            col_factor = (new_v_col)**-0.5
            y = (grads * np.expand_dims(row_factor, axis=-1) *
                 np.expand_dims(col_factor, axis=-2))
        else:
            v = slots[next_slot]
            next_slot += 1
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v)**-0.5

        if self._do_clipping:
            # Shrink y only when its RMS exceeds the clipping threshold.
            clipping_denom = (np.maximum(
                1.0,
                np.sqrt(np.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots[next_slot]
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_params = (1 - weight_decay_rate) * params - subtrahend
        return new_params, updates
Пример #6
0
def MakeTargetMask(target, pad=0):
    """Build the attention mask that hides padding and future positions.

    The padding mask (token differs from ``pad``) is combined with a
    lower-triangular mask so that each position can only see itself and
    earlier positions.

    Args:
        target: array of token ids, indexed here as two-dimensional with the
            sequence along the last axis.
        pad: token id treated as padding.

    Returns:
        The combined mask with an extra axis inserted at position 1.
    """
    length = target.shape[-1]
    not_padding = (target != pad)[:, None, :]
    # Ones on and below the main diagonal of the last two axes.
    all_ones = onp.ones((1, length, length), dtype=not_padding.dtype)
    lower_triangle = onp.tril(all_ones, k=0)
    return np.expand_dims(not_padding & lower_triangle, axis=1)
Пример #7
0
 def forward(self, inputs, params=(), state=(), **kwargs):
   """Add the encoding stored in ``params`` to ``inputs``.

   In 'train' and 'eval' modes the first ``inputs.shape[1]`` entries along
   axis 1 of ``params`` are added and ``state`` is returned unchanged. In
   'predict' mode the single entry at index ``state`` is added and
   ``state + 1`` is returned.

   Args:
     inputs: array to which the encoding is added.
     params: encoding array, indexed along its second axis.
     state: unused outside 'predict' mode; in 'predict' mode, the index of
       the encoding entry to use.
     **kwargs: accepted but not used.

   Returns:
     A pair ``(outputs, new_state)``.
   """
   if self._mode not in ('train', 'eval'):
     assert self._mode == 'predict'
     # Fast inference: each call uses the next element of the encoding
     # sequence; the current index is carried in state.
     current = np.expand_dims(params[:, state, :], 1)
     return (inputs + current, state + 1)
   length = np.shape(inputs)[1]
   return (inputs + params[:, :length, :], state)
Пример #8
0
  def update(self, i, g, x, state):
    """Compute the new value and new optimizer state for one parameter.

    Maintains a decayed average of the squared gradients (kept as separate
    row and column averages when ``self._factored`` is set and ``x`` has at
    least two axes, otherwise per element), scales the gradient by the
    inverse square root of that average, optionally clips the RMS of the
    scaled gradient, optionally applies momentum, and subtracts the result
    from ``x``.

    Args:
      i: step index, forwarded to ``self._decay_rate`` and
        ``self._step_size``.
      g: gradient array for this parameter.
      x: current parameter array.
      state: sequence of state arrays ordered as ``[v_row, v_col]``
        (factored) or ``[v]`` (unfactored), followed by ``m`` when
        ``self._beta1`` is truthy. It is only read; the caller's sequence is
        left unmodified.

    Returns:
      A pair ``(new_x, updates)``: the updated parameter and the list of new
      state values in the same order as ``state``.

    Raises:
      IndexError: if ``state`` holds fewer entries than required.
    """
    updates = []
    # Read cursor into ``state``: entries are consumed by index rather than
    # popped, so the caller's sequence is not emptied.
    next_slot = 0

    decay_rate = self._decay_rate(i)
    update_scale = self._step_size(i)
    if self._multiply_by_parameter_scale:
      # Scale the step by the RMS of the parameter, floored at epsilon2.
      update_scale *= np.maximum(np.sqrt(np.mean(x * x)), self._epsilon2)
    mixing_rate = 1.0 - decay_rate

    g_sqr = g * g + self._epsilon1
    if self._factored and len(x.shape) >= 2:
      v_row = state[next_slot]
      v_col = state[next_slot + 1]
      next_slot += 2
      new_v_row = decay_rate * v_row + mixing_rate * np.mean(g_sqr, axis=-1)
      new_v_col = decay_rate * v_col + mixing_rate * np.mean(g_sqr, axis=-2)
      updates.extend([new_v_row, new_v_col])
      # The row statistics are normalised by their mean so that the product
      # of the row and column factors has the right scale.
      row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
      row_factor = (new_v_row / row_col_mean)**-0.5
      col_factor = (new_v_col)**-0.5
      y = (
          g * np.expand_dims(row_factor, axis=-1) *
          np.expand_dims(col_factor, axis=-2))
    else:
      v = state[next_slot]
      next_slot += 1
      new_v = decay_rate * v + mixing_rate * g_sqr
      updates.append(new_v)
      y = g * (new_v)**-0.5

    if self._clipping_threshold is not None:
      # Shrink y only when its RMS exceeds the clipping threshold.
      clipping_denom = (
          np.maximum(1.0,
                     np.sqrt(np.mean(y * y)) / self._clipping_threshold))
      y /= clipping_denom

    subtrahend = update_scale * y
    if self._beta1:
      m = state[next_slot]
      new_m = self._beta1 * m + (1.0 - self._beta1) * subtrahend
      subtrahend = new_m
      updates.append(new_m)

    new_x = x - subtrahend
    return new_x, updates