def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, diag(diag_sigma)) with a diagonal covariance."""
  a = mu.shape[-1] * np.log(2 * np.pi)
  b = np.sum(np.log(diag_sigma), axis=-1)
  # Parenthesize (x - mu): without it, operator precedence computes
  # x - (mu / diag_sigma) instead of the intended whitened residual.
  y = (x - mu) / diag_sigma
  y = np.expand_dims(y, axis=-1)
  xm = np.expand_dims(x - mu, axis=-2)
  # Quadratic (Mahalanobis) term: (x - mu)^T diag(diag_sigma)^{-1} (x - mu).
  c = np.matmul(xm, y)
  c = np.squeeze(np.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
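# Illustrative sanity check (a sketch added here, not part of the original
# module; `_check_log_gaussian_diag_pdf` is a hypothetical helper and `np` is
# assumed to be NumPy-compatible). For a diagonal covariance, the joint
# log-density should equal the sum of per-dimension univariate log-densities.
def _check_log_gaussian_diag_pdf():
  x = np.array([0.5, -1.0])
  mu = np.array([0.0, 1.0])
  diag_sigma = np.array([1.0, 2.0])  # per-dimension variances
  joint = log_gaussian_diag_pdf(x, mu, diag_sigma)
  per_dim = -0.5 * (np.log(2 * np.pi * diag_sigma)
                    + (x - mu)**2 / diag_sigma)
  assert np.allclose(joint, np.sum(per_dim))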
def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, sigma) with a full covariance matrix."""
  a = mu.shape[-1] * np.log(2 * np.pi)
  # log|sigma| via the numerically stable signed log-determinant.
  _, b = np.linalg.slogdet(sigma)
  # Solve sigma @ y = (x - mu) rather than forming sigma^{-1} explicitly.
  y = np.linalg.solve(sigma, x - mu)
  y = np.expand_dims(y, axis=-1)
  xm = np.expand_dims(x - mu, axis=-2)
  # Quadratic (Mahalanobis) term: (x - mu)^T sigma^{-1} (x - mu).
  c = np.matmul(xm, y)
  c = np.squeeze(np.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
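# Illustrative cross-check (a hypothetical sketch, not original code): with
# sigma = diag(diag_sigma), the full-covariance density must agree with the
# diagonal one above.
def _check_log_gaussian_pdf():
  x = np.array([0.5, -1.0])
  mu = np.array([0.0, 1.0])
  diag_sigma = np.array([1.0, 2.0])
  full = log_gaussian_pdf(x, mu, np.diag(diag_sigma))
  diag = log_gaussian_diag_pdf(x, mu, diag_sigma)
  assert np.allclose(full, diag)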
def update(self, step, grads, params, slots, opt_params):
  updates = []
  learning_rate = opt_params["learning_rate"]
  beta1 = opt_params["beta1"]
  decay_rate = opt_params["decay_rate"]
  clipping_threshold = opt_params["clipping_threshold"]
  weight_decay_rate = opt_params["weight_decay_rate"]
  epsilon1 = opt_params["epsilon1"]
  epsilon2 = opt_params["epsilon2"]

  decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
  update_scale = learning_rate
  if self._multiply_by_parameter_scale:
    # Scale updates by the RMS of the parameters, floored at epsilon2.
    update_scale *= np.maximum(np.sqrt(np.mean(params * params)), epsilon2)
  mixing_rate = 1.0 - decay_rate

  grads_sqr = grads * grads + epsilon1
  if self._factored and len(params.shape) >= 2:
    # Factored second moment: keep only row and column means of grads_sqr
    # and reconstruct the full accumulator as their normalized outer product.
    v_row = slots.pop(0)
    v_col = slots.pop(0)
    new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr, axis=-1)
    new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr, axis=-2)
    updates.extend([new_v_row, new_v_col])
    row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (new_v_row / row_col_mean)**-0.5
    col_factor = (new_v_col)**-0.5
    y = (grads * np.expand_dims(row_factor, axis=-1)
         * np.expand_dims(col_factor, axis=-2))
  else:
    # Unfactored fallback: a full second-moment accumulator, as in Adam.
    v = slots.pop(0)
    new_v = decay_rate * v + mixing_rate * grads_sqr
    updates.append(new_v)
    y = grads * (new_v)**-0.5

  if self._do_clipping:
    # Clip the update so its RMS does not exceed clipping_threshold.
    clipping_denom = (np.maximum(
        1.0, np.sqrt(np.mean(y * y)) / clipping_threshold))
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._do_momentum:
    m = slots.pop(0)
    new_m = beta1 * m + (1.0 - beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_params = (1 - weight_decay_rate) * params - subtrahend
  # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
  return new_params.astype(params.dtype), updates
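# Standalone sketch of the factored second-moment trick used above (added for
# illustration; `_factored_rms_update` and its default arguments are
# hypothetical). Instead of storing a full accumulator V with the shape of a
# (rows x cols) weight matrix, Adafactor keeps per-row and per-column means
# and reconstructs V ~= outer(v_row, v_col) / mean(v_row), reducing the
# accumulator memory from O(rows * cols) to O(rows + cols).
def _factored_rms_update(grads, v_row, v_col, decay_rate=0.8, epsilon1=1e-30):
  grads_sqr = grads * grads + epsilon1
  mixing_rate = 1.0 - decay_rate
  new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr, axis=-1)
  new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr, axis=-2)
  row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
  row_factor = (new_v_row / row_col_mean)**-0.5
  col_factor = (new_v_col)**-0.5
  # grads / sqrt(V) under the factored approximation of V.
  y = (grads * np.expand_dims(row_factor, axis=-1)
       * np.expand_dims(col_factor, axis=-2))
  return y, new_v_row, new_v_col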
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  if self._mode in ('train', 'eval'):
    x = inputs
    symbol_size = np.shape(x)[1]
    px = weights[:, :symbol_size, :]
    if self._dropout == 0:
      return (x + px, state)
    else:
      noise_shape = list(px.shape)
      # Broadcast the dropout mask along the requested dimensions so the
      # same keep/drop decision is shared across each of them.
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if backend.get_name() == 'jax':
        keep_prob = jax.lax.tie_in(
            x, np.full((), keep_prob, dtype=x.dtype))
      keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return (x + px * multiplier, state)
  else:
    assert self._mode == 'predict'
    assert self._dropout == 0
    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    return (inputs + np.expand_dims(weights[:, state, :], 1), state + 1)
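# Minimal sketch of the two code paths above (illustrative only; both helpers
# are hypothetical, and `weights` is assumed to be the positional-encoding
# table of shape (1, max_len, d_feature)). In train/eval mode, a whole
# sequence gets the first `symbol_size` encodings at once; in predict mode,
# one position is added per call while the integer `state` tracks the index.
def _encode_sequence(x, weights):
  symbol_size = np.shape(x)[1]
  return x + weights[:, :symbol_size, :]

def _encode_predict_step(x_t, weights, state):
  # x_t has shape (batch, 1, d_feature); pick out the encoding at `state`.
  return x_t + np.expand_dims(weights[:, state, :], 1), state + 1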