def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
  """Evaluate log N(x | mu, sigma) for a full covariance matrix.

  Args:
    x: point at which the density is evaluated.
    mu: mean; its last dimension gives the event size used in the
      normalisation constant.
    sigma: covariance matrix, passed to `slogdet` and `solve`.

  Returns:
    -0.5 * (d * log(2 * pi) + log|det(sigma)| + (x - mu)^T sigma^-1 (x - mu)),
    where d is `mu.shape[-1]`.
  """
  norm_term = mu.shape[-1] * np.log(2 * np.pi)
  # slogdet returns (sign, log|det|); only the log-magnitude is used.
  logdet_term = np.linalg.slogdet(sigma)[1]

  diff = x - mu
  # sigma^-1 (x - mu) via a linear solve rather than an explicit inverse.
  solved = np.linalg.solve(sigma, diff)

  # Row vector times column vector gives the quadratic form as a 1x1 matrix.
  quad_term = np.matmul(
      np.expand_dims(diff, axis=-2), np.expand_dims(solved, axis=-1))
  quad_term = np.squeeze(quad_term, axis=(-2, -1))

  return -0.5 * (norm_term + logdet_term + quad_term)
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, diag(diag_sigma)).

  Args:
    x: point(s) at which the density is evaluated; the last axis is reduced.
    mu: mean; its last dimension gives the event size used in the
      normalisation constant.
    diag_sigma: diagonal entries of the covariance matrix (variances), laid
      out along the last axis.

  Returns:
    -0.5 * (d * log(2 * pi) + sum(log(diag_sigma)) +
            sum((x - mu)**2 / diag_sigma)),
    reduced over the last axis, where d is `mu.shape[-1]`.
  """
  a = mu.shape[-1] * np.log(2 * np.pi)
  # log-determinant of a diagonal matrix is the sum of the log-diagonal.
  b = np.sum(np.log(diag_sigma), axis=-1)
  diff = x - mu
  # Quadratic form (x - mu)^T diag(1 / diag_sigma) (x - mu). The difference is
  # taken *before* dividing by the variances; the previous code evaluated
  # `x - mu / diag_sigma`, which binds as `x - (mu / diag_sigma)`.
  c = np.sum(diff * diff / diag_sigma, axis=-1)
  return -0.5 * (a + b + c)
def update(self, step, grads, params, slots, opt_params):
  """Compute new parameters and optimizer slots for one step.

  The second-moment estimate is kept either as a full array or, when
  `self._factored` is set and `params` has at least two dimensions, as
  separate row and column statistics over the last two axes.

  Args:
    step: current step, forwarded to `self._decay_rate_pow`.
    grads: gradient array for `params`.
    params: current parameter array.
    slots: list of optimizer slots, consumed from the front with `pop(0)`
      (this mutates the caller's list).
    opt_params: dict with keys "learning_rate", "beta1", "decay_rate",
      "clipping_threshold", "weight_decay_rate", "epsilon1", "epsilon2".

  Returns:
    A pair `(new_params, new_slots)`; `new_params` is cast to `params.dtype`
    and `new_slots` lists the updated slots in the order they were consumed.
  """
  lr = opt_params["learning_rate"]
  momentum_coef = opt_params["beta1"]
  decay_exponent = opt_params["decay_rate"]
  clip_at = opt_params["clipping_threshold"]
  weight_decay = opt_params["weight_decay_rate"]
  eps_grad = opt_params["epsilon1"]
  eps_scale = opt_params["epsilon2"]

  decay = self._decay_rate_pow(step, exponent=decay_exponent)
  mix = 1.0 - decay

  scale = lr
  if self._multiply_by_parameter_scale:
    # Scale the step by the RMS of the parameters, floored at epsilon2.
    scale *= np.maximum(np.sqrt(np.mean(params * params)), eps_scale)

  squared = grads * grads + eps_grad
  new_slots = []

  if not (self._factored and len(params.shape) >= 2):
    # Unfactored: one running average with the same shape as the gradient.
    second_moment = decay * slots.pop(0) + mix * squared
    new_slots.append(second_moment)
    direction = grads * second_moment**-0.5
  else:
    # Factored: running averages of the row-wise and column-wise means.
    old_row = slots.pop(0)
    old_col = slots.pop(0)
    row_acc = decay * old_row + mix * np.mean(squared, axis=-1)
    col_acc = decay * old_col + mix * np.mean(squared, axis=-2)
    new_slots += [row_acc, col_acc]
    row_mean = np.mean(row_acc, axis=-1, keepdims=True)
    row_scaling = (row_acc / row_mean)**-0.5
    col_scaling = col_acc**-0.5
    direction = (grads * np.expand_dims(row_scaling, axis=-1) *
                 np.expand_dims(col_scaling, axis=-2))

  if self._do_clipping:
    # Shrink the update when its RMS exceeds the clipping threshold.
    rms = np.sqrt(np.mean(direction * direction))
    direction /= np.maximum(1.0, rms / clip_at)

  delta = scale * direction
  if self._do_momentum:
    delta = momentum_coef * slots.pop(0) + (1.0 - momentum_coef) * delta
    new_slots.append(delta)

  stepped = (1 - weight_decay) * params - delta
  # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
  return stepped.astype(params.dtype), new_slots
def make_target_mask(target, pad=0):
  """Build an attention mask that hides padding and future positions.

  Args:
    target: array of target ids, indexed as `[:, np.newaxis, :]` (so it is
      expected to be two-dimensional; presumably (batch, length)).
    pad: id that marks padding positions.

  Returns:
    The non-padding mask combined with `stax.causal_mask` for the target
    length, with one extra axis inserted at position 1.
  """
  not_pad = (target != pad)[:, np.newaxis, :]
  causal = stax.causal_mask(target.shape[-1])
  # Cast back so the result keeps the dtype of the padding mask.
  combined = (not_pad & causal).astype(not_pad.dtype)
  return np.expand_dims(combined, axis=1)
def update(self, step, grads, params, slots, opt_params):
  """Compute new parameters and optimizer slots for one step.

  The second-moment estimate is kept either as a full array or, when
  `self._factored` is set and `params` has at least two dimensions, as
  separate row and column statistics over the last two axes.

  Args:
    step: current step, forwarded to `self._decay_rate_pow`.
    grads: gradient array for `params`.
    params: current parameter array.
    slots: list of optimizer slots, consumed from the front with `pop(0)`
      (this mutates the caller's list).
    opt_params: 7-item sequence unpacked as (learning_rate, beta1, decay_rate,
      clipping_threshold, weight_decay_rate, epsilon1, epsilon2).

  Returns:
    A pair `(new_params, new_slots)`; `new_slots` lists the updated slots in
    the order they were consumed.
  """
  (lr, momentum_coef, decay_exponent, clip_at, weight_decay,
   eps_grad, eps_scale) = opt_params

  decay = self._decay_rate_pow(step, exponent=decay_exponent)
  mix = 1.0 - decay

  scale = lr
  if self._multiply_by_parameter_scale:
    # Scale the step by the RMS of the parameters, floored at epsilon2.
    scale *= np.maximum(np.sqrt(np.mean(params * params)), eps_scale)

  squared = grads * grads + eps_grad
  new_slots = []

  if not (self._factored and len(params.shape) >= 2):
    # Unfactored: one running average with the same shape as the gradient.
    second_moment = decay * slots.pop(0) + mix * squared
    new_slots.append(second_moment)
    direction = grads * second_moment**-0.5
  else:
    # Factored: running averages of the row-wise and column-wise means.
    old_row = slots.pop(0)
    old_col = slots.pop(0)
    row_acc = decay * old_row + mix * np.mean(squared, axis=-1)
    col_acc = decay * old_col + mix * np.mean(squared, axis=-2)
    new_slots += [row_acc, col_acc]
    row_mean = np.mean(row_acc, axis=-1, keepdims=True)
    row_scaling = (row_acc / row_mean)**-0.5
    col_scaling = col_acc**-0.5
    direction = (grads * np.expand_dims(row_scaling, axis=-1) *
                 np.expand_dims(col_scaling, axis=-2))

  if self._do_clipping:
    # Shrink the update when its RMS exceeds the clipping threshold.
    rms = np.sqrt(np.mean(direction * direction))
    direction /= np.maximum(1.0, rms / clip_at)

  delta = scale * direction
  if self._do_momentum:
    delta = momentum_coef * slots.pop(0) + (1.0 - momentum_coef) * delta
    new_slots.append(delta)

  return (1 - weight_decay) * params - delta, new_slots
def MakeTargetMask(target, pad=0):
  """Build an attention mask that hides padding and future positions.

  Args:
    target: array of target ids, indexed as `[:, np.newaxis, :]` (so it is
      expected to be two-dimensional; presumably (batch, length)).
    pad: id that marks padding positions.

  Returns:
    The non-padding mask AND-ed with a lower-triangular (causal) mask over
    the last axis length, with one extra axis inserted at position 1.
  """
  length = target.shape[-1]
  not_pad = (target != pad)[:, np.newaxis, :]
  # Lower triangle (diagonal included): position i may attend to j <= i.
  lower_tri = onp.tril(
      onp.ones((1, length, length), dtype=not_pad.dtype), k=0)
  return np.expand_dims(not_pad & lower_tri, axis=1)
def forward(self, inputs, params=(), state=(), **kwargs):
  """Add the encoding held in `params` to `inputs`.

  Args:
    inputs: input array; in 'train'/'eval' mode its axis 1 length selects how
      many encoding positions are used.
    params: encoding array, indexed along axis 1 by position.
    state: in 'predict' mode, the index of the next encoding position;
      returned unchanged in 'train'/'eval' mode.
    **kwargs: accepted but not used here.

  Returns:
    A pair `(outputs, new_state)`.
  """
  if self._mode not in ('train', 'eval'):
    assert self._mode == 'predict'
    # Fast inference: return consecutive elements of the encoding sequence,
    # storing the index in state.
    current = np.expand_dims(params[:, state, :], 1)
    return (inputs + current, state + 1)

  # Training / evaluation: add the first `length` encoding positions at once.
  length = np.shape(inputs)[1]
  return (inputs + params[:, :length, :], state)
def update(self, i, g, x, state):
  """Compute the new parameter value and optimizer state for one step.

  The second-moment estimate is kept either as a full array or, when
  `self._factored` is set and `x` has at least two dimensions, as separate
  row and column statistics over the last two axes.

  Args:
    i: current step, forwarded to `self._decay_rate` and `self._step_size`.
    g: gradient array for `x`.
    x: current parameter array.
    state: list of optimizer state entries, consumed from the front with
      `pop(0)` (this mutates the caller's list).

  Returns:
    A pair `(new_x, new_state)`; `new_state` lists the updated entries in the
    order they were consumed.
  """
  decay = self._decay_rate(i)
  scale = self._step_size(i)
  if self._multiply_by_parameter_scale:
    # Scale the step by the RMS of the parameters, floored at epsilon2.
    scale *= np.maximum(np.sqrt(np.mean(x * x)), self._epsilon2)
  mix = 1.0 - decay

  squared = g * g + self._epsilon1
  new_state = []

  if not (self._factored and len(x.shape) >= 2):
    # Unfactored: one running average with the same shape as the gradient.
    second_moment = decay * state.pop(0) + mix * squared
    new_state.append(second_moment)
    direction = g * second_moment**-0.5
  else:
    # Factored: running averages of the row-wise and column-wise means.
    old_row = state.pop(0)
    old_col = state.pop(0)
    row_acc = decay * old_row + mix * np.mean(squared, axis=-1)
    col_acc = decay * old_col + mix * np.mean(squared, axis=-2)
    new_state += [row_acc, col_acc]
    row_mean = np.mean(row_acc, axis=-1, keepdims=True)
    row_scaling = (row_acc / row_mean)**-0.5
    col_scaling = col_acc**-0.5
    direction = (g * np.expand_dims(row_scaling, axis=-1) *
                 np.expand_dims(col_scaling, axis=-2))

  if self._clipping_threshold is not None:
    # Shrink the update when its RMS exceeds the clipping threshold.
    rms = np.sqrt(np.mean(direction * direction))
    direction /= np.maximum(1.0, rms / self._clipping_threshold)

  delta = scale * direction
  if self._beta1:
    # Momentum is only tracked when beta1 is truthy (non-zero / not None).
    delta = self._beta1 * state.pop(0) + (1.0 - self._beta1) * delta
    new_state.append(delta)

  return x - delta, new_state