Example #1
0
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, diag(diag_sigma)) for a diagonal covariance.

  The last axis of every argument is the event axis; any leading axes are
  broadcast against each other.

  Args:
    x: points at which the log-density is evaluated.
    mu: mean of the Gaussian; its last dimension gives the event size.
    diag_sigma: diagonal entries of the covariance matrix (the variances).

  Returns:
    The log-density with the event axis reduced away.
  """
  # Normalisation term: d * log(2 * pi).
  a = mu.shape[-1] * np.log(2 * np.pi)
  # Log-determinant of a diagonal matrix is the sum of the log diagonal.
  b = np.sum(np.log(diag_sigma), axis=-1)
  # Quadratic form (x - mu)^T Sigma^-1 (x - mu).  The whole difference must be
  # divided by the variances (not just mu), hence the explicit `diff`.
  diff = x - mu
  c = np.sum(diff * diff / diag_sigma, axis=-1)
  return -0.5 * (a + b + c)
Example #2
0
def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, sigma) for a full covariance matrix.

  Args:
    x: points at which the log-density is evaluated; the last axis is the
      event axis and leading axes are treated as batch axes.
    mu: mean of the Gaussian; its last dimension gives the event size.
    sigma: covariance matrix (or a stack of them) whose last two axes are
      square with the event size.

  Returns:
    The log-density with the event axis reduced away.
  """
  # Normalisation term: d * log(2 * pi).
  a = mu.shape[-1] * np.log(2 * np.pi)
  # Only the log-magnitude of the determinant is used; the sign is discarded.
  _, b = np.linalg.slogdet(sigma)
  diff = x - mu
  # Solve against an explicit column vector.  Passing `diff` directly lets
  # `solve` decide from its rank whether it is a vector or a matrix, which
  # mis-handles a batch of points sharing one covariance matrix.
  y = np.linalg.solve(sigma, np.expand_dims(diff, axis=-1))
  xm = np.expand_dims(diff, axis=-2)
  # (..., 1, d) @ (..., d, 1) -> (..., 1, 1): the quadratic form per point.
  c = np.matmul(xm, y)
  c = np.squeeze(np.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
Example #3
0
    def update(self, step, grads, params, slots, opt_params):
        """Compute the updated parameters and optimizer slots for one tensor.

        The second-moment estimate of the gradient is kept either per element
        or, when factoring is enabled and ``params`` has at least two axes, as
        separate means over the last and second-to-last axes.  The scaled
        gradient is optionally clipped by its root-mean-square, optionally
        passed through a momentum accumulator, and subtracted from the
        weight-decayed parameters.

        Neither ``slots`` nor the values in ``opt_params`` are modified.

        Args:
            step: current step, forwarded to ``self._decay_rate_pow``.
            grads: gradient for ``params``.
            params: current parameter values.
            slots: sequence of optimizer slots, in order: ``[v_row, v_col]``
                when factored or ``[v]`` otherwise, followed by ``m`` when
                momentum is enabled.
            opt_params: mapping with the keys ``learning_rate``, ``beta1``,
                ``decay_rate``, ``clipping_threshold``, ``weight_decay_rate``,
                ``epsilon1`` and ``epsilon2``.

        Returns:
            A pair ``(new_params, updates)`` where ``new_params`` has the dtype
            of ``params`` and ``updates`` lists the new slot values in the
            same order as ``slots``.

        Raises:
            IndexError: if ``slots`` holds fewer entries than required.
        """
        learning_rate = opt_params["learning_rate"]
        beta1 = opt_params["beta1"]
        decay_rate = opt_params["decay_rate"]
        clipping_threshold = opt_params["clipping_threshold"]
        weight_decay_rate = opt_params["weight_decay_rate"]
        epsilon1 = opt_params["epsilon1"]
        epsilon2 = opt_params["epsilon2"]

        # The configured decay rate acts as an exponent for the step-dependent
        # schedule; the result is the effective decay for this step.
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        mixing_rate = 1.0 - decay_rate

        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            # Out-of-place multiply: an in-place `*=` would overwrite the
            # caller's learning rate when it is passed as an array.
            param_rms = np.sqrt(np.mean(params * params))
            update_scale = update_scale * np.maximum(param_rms, epsilon2)

        updates = []
        grads_sqr = grads * grads + epsilon1
        # Slots are read by index so the caller's list is left intact.
        if self._factored and len(params.shape) >= 2:
            v_row, v_col = slots[0], slots[1]
            next_slot = 2
            new_v_row = (decay_rate * v_row +
                         mixing_rate * np.mean(grads_sqr, axis=-1))
            new_v_col = (decay_rate * v_col +
                         mixing_rate * np.mean(grads_sqr, axis=-2))
            updates.extend([new_v_row, new_v_col])
            row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (new_v_row / row_col_mean)**-0.5
            col_factor = (new_v_col)**-0.5
            y = (grads * np.expand_dims(row_factor, axis=-1) *
                 np.expand_dims(col_factor, axis=-2))
        else:
            v = slots[0]
            next_slot = 1
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v)**-0.5

        if self._do_clipping:
            # Shrink the update only when its RMS exceeds the threshold.
            clipping_denom = np.maximum(
                1.0, np.sqrt(np.mean(y * y)) / clipping_threshold)
            y = y / clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots[next_slot]
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_params = (1 - weight_decay_rate) * params - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_params.astype(params.dtype), updates
Example #4
0
 def forward_with_state(self,
                        inputs,
                        weights=base.EMPTY_WEIGHTS,
                        state=base.EMPTY_STATE,
                        rng=None,
                        **kwargs):
     """Add positional weights to `inputs` and return the result with state.

     In 'train' and 'eval' modes the first `inputs.shape[1]` positions of
     `weights` (along axis 1) are added to `inputs`, optionally after
     dropout is applied to them, and `state` is returned unchanged. In
     'predict' mode `state` is the index of the current position: the single
     slice of `weights` at that index is added and `state + 1` is returned.

     Args:
         inputs: input array; axis 1 is used as the position axis.
         weights: positional weights, indexed as a rank-3 array.
         state: position index, only read in 'predict' mode.
         rng: random key passed to the backend Bernoulli sampler when
             dropout is non-zero.
         **kwargs: accepted and ignored.

     Returns:
         A pair `(outputs, new_state)`.
     """
     if self._mode not in ('train', 'eval'):
         assert self._mode == 'predict'
         assert self._dropout == 0
         # State is only used for fast inference, where the model is fed one
         # position at a time: it holds the index of the current position
         # and is advanced by one on every call.
         current = np.expand_dims(weights[:, state, :], 1)
         return (inputs + current, state + 1)

     length = np.shape(inputs)[1]
     positional = weights[:, :length, :]
     if self._dropout == 0:
         return (inputs + positional, state)

     # Dropout mask shape: size 1 along every broadcast dimension so the
     # same keep/drop decision is shared across it.
     mask_shape = list(positional.shape)
     for axis in self._dropout_broadcast_dims:
         mask_shape[axis] = 1
     keep_prob = 1.0 - self._dropout
     if backend.get_name() == 'jax':
         keep_prob = jax.lax.tie_in(
             inputs, np.full((), keep_prob, dtype=inputs.dtype))
     keep_mask = backend.random.bernoulli(rng, keep_prob, tuple(mask_shape))
     # Rescale kept entries by 1 / keep_prob.
     scale = keep_mask.astype(inputs.dtype) / keep_prob
     return (inputs + positional * scale, state)