# Imports assumed by the snippets below (trax conventions; not present in
# the original excerpt).
import jax

from trax import fastmath
from trax.fastmath import numpy as jnp


def update(self, step, grads, weights, slots, opt_params):
  """Computes one Adafactor update; returns (new_weights, new_slots)."""
  updates = []
  learning_rate = opt_params['learning_rate']
  beta1 = opt_params['beta1']
  decay_rate = opt_params['decay_rate']
  clipping_threshold = opt_params['clipping_threshold']
  weight_decay_rate = opt_params['weight_decay_rate']
  weight_decay_n_steps = opt_params['weight_decay_n_steps']
  weight_decay_rate = jnp.where(
      weight_decay_n_steps < 1,  # if weight_decay_n_steps == 0, ignore it
      weight_decay_rate,
      # Otherwise anneal the rate linearly to zero over weight_decay_n_steps.
      # The max(., 1.0) guards against 0/0 in this branch, which jnp.where
      # discards when weight_decay_n_steps == 0.
      (weight_decay_rate * jnp.maximum(weight_decay_n_steps - step, 0.0) /
       jnp.maximum(weight_decay_n_steps, 1.0)))
  epsilon1 = opt_params['epsilon1']
  epsilon2 = opt_params['epsilon2']
  decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
  update_scale = learning_rate
  if self._multiply_by_parameter_scale:
    update_scale *= jnp.maximum(jnp.sqrt(jnp.mean(weights * weights)),
                                epsilon2)
  mixing_rate = 1.0 - decay_rate

  grads_sqr = grads * grads
  if self._factored and len(weights.shape) >= 2:
    # Factored second moment: keep per-row and per-column accumulators
    # instead of a full-size statistic.
    v_row = slots.pop(0)
    v_col = slots.pop(0)
    new_v_row = (decay_rate * v_row +
                 mixing_rate * jnp.mean(grads_sqr, axis=-1))
    new_v_col = (decay_rate * v_col +
                 mixing_rate * jnp.mean(grads_sqr, axis=-2))
    updates.extend([new_v_row, new_v_col])
    row_mean = jnp.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (row_mean / (new_v_row + epsilon1))**0.5
    col_factor = (new_v_col + epsilon1)**-0.5
    y = (grads * jnp.expand_dims(row_factor, axis=-1) *
         jnp.expand_dims(col_factor, axis=-2))
  else:
    v = slots.pop(0)
    new_v = decay_rate * v + mixing_rate * grads_sqr
    updates.append(new_v)
    y = grads * (new_v + epsilon1)**-0.5

  if self._do_clipping:
    clipping_denom = jnp.maximum(
        1.0, jnp.sqrt(jnp.mean(y * y)) / clipping_threshold)
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._do_momentum:
    m = slots.pop(0)
    new_m = beta1 * m + (1.0 - beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_weights = (1 - weight_decay_rate) * weights - subtrahend
  # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
  return new_weights.astype(weights.dtype), updates
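
# --- Illustrative sketch (not from the original source) ---
# `self._decay_rate_pow` is defined elsewhere; the standard Adafactor
# second-moment decay schedule (Shazeer & Stern, 2018) that it typically
# implements is 1 - (step + 1)^(-exponent):
def _example_decay_rate_pow(step, exponent=0.8):
  t = jnp.array(step, jnp.float32) + 1.0
  return 1.0 - t**(-exponent)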
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
  """Returns `log N(x | mu, diag(diag_sigma))`.

  Args:
    x: point at which to evaluate the log-density.
    mu: mean vector of the Gaussian.
    diag_sigma: diagonal of the covariance matrix (per-dimension variances).
  """
  a = mu.shape[-1] * jnp.log(2 * jnp.pi)
  b = jnp.sum(jnp.log(diag_sigma), axis=-1)
  # Note: the parentheses are required; `x - mu / diag_sigma` would divide
  # only mu and not yield the Mahalanobis residual.
  y = (x - mu) / diag_sigma
  y = jnp.expand_dims(y, axis=-1)
  xm = jnp.expand_dims(x - mu, axis=-2)
  c = jnp.matmul(xm, y)
  c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
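
# --- Illustrative usage (hypothetical values) ---
# For a standard normal in d dimensions, log N(0 | 0, I) = -d/2 * log(2*pi).
def _example_log_gaussian_diag_pdf():
  d = 3
  x = jnp.zeros(d)
  mu = jnp.zeros(d)
  diag_sigma = jnp.ones(d)
  return log_gaussian_diag_pdf(x, mu, diag_sigma)  # ~ -1.5 * log(2*pi)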
def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
  """Returns `log N(x | mu, sigma)`.

  Args:
    x: point at which to evaluate the log-density.
    mu: mean vector of the Gaussian.
    sigma: full covariance matrix.
  """
  a = mu.shape[-1] * jnp.log(2 * jnp.pi)
  _, b = jnp.linalg.slogdet(sigma)
  y = jnp.linalg.solve(sigma, x - mu)  # sigma^{-1} (x - mu)
  y = jnp.expand_dims(y, axis=-1)
  xm = jnp.expand_dims(x - mu, axis=-2)
  c = jnp.matmul(xm, y)
  c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
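
# --- Illustrative usage (hypothetical values) ---
# With an identity covariance this reduces to the diagonal case above.
def _example_log_gaussian_pdf():
  d = 2
  x = jnp.zeros(d)
  mu = jnp.zeros(d)
  sigma = jnp.eye(d)
  return log_gaussian_pdf(x, mu, sigma)  # ~ -1.0 * log(2*pi)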
def forward(self, inputs):
  """Returns the input activations, with added positional information."""
  if self._mode != 'predict':
    x = inputs
    symbol_size = jnp.shape(x)[1]
    if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
      px = self.weights[:, :symbol_size, :]
    else:
      rng1, rng2 = fastmath.random.split(self.rng, 2)
      start = fastmath.random.randint(rng1, (), 0, self._max_offset_to_add)
      start_from_zero = fastmath.random.uniform(rng2, (), jnp.float32, 0, 1)
      start = jnp.where(start_from_zero < self._start_from_zero_prob,
                        jnp.zeros((), dtype=jnp.int32), start)
      px = fastmath.dynamic_slice_in_dim(self.weights, start, symbol_size,
                                         axis=1)
    if self._dropout == 0:
      return x + px
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                       tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return x + px * multiplier
  else:
    if self._dropout != 0:
      raise ValueError(f'In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')

    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    state = self.state
    if inputs.shape[1] == 1:
      self.state = state + 1
      return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
    else:
      emb = []
      for i in range(inputs.shape[0]):
        emb.append(
            fastmath.dynamic_slice_in_dim(
                self.weights[0], state[i], inputs.shape[1], axis=0))
      self.state = state + inputs.shape[1]
      res = inputs + jnp.stack(emb, 0)
      return res
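
# --- Illustrative sketch (not from the original source) ---
# The noise-shape trick above implements "broadcast" dropout: setting a
# dimension of the noise shape to 1 shares one keep/drop decision across
# that axis (e.g. dropping whole positions at once). A minimal standalone
# version, assuming trax's fastmath.random.bernoulli:
def _example_broadcast_dropout(rng, x, rate=0.1, broadcast_dims=(-2,)):
  noise_shape = list(x.shape)
  for dim in broadcast_dims:
    noise_shape[dim] = 1
  keep = fastmath.random.bernoulli(rng, 1.0 - rate, tuple(noise_shape))
  return x * keep.astype(x.dtype) / (1.0 - rate)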
def forward(self, inputs):
  """Returns the input activations, with added positional information."""
  if self._mode != 'predict':
    x = inputs
    symbol_size = jnp.shape(x)[1]
    px = self.weights[:, :symbol_size, :]
    if self._dropout == 0:
      return x + px
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if fastmath.is_backend(fastmath.Backend.JAX):
        # `tie_in` has been a no-op since JAX 0.2.0 (omnistaging); it is
        # kept here for compatibility with older JAX versions.
        keep_prob = jax.lax.tie_in(x, jnp.full((), keep_prob, dtype=x.dtype))
      keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                       tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return x + px * multiplier
  else:
    if self._dropout != 0:
      raise ValueError(f'In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')

    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    state = self.state
    if inputs.shape[1] == 1:
      self.state = state + 1
      return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
    else:
      emb = []
      for i in range(inputs.shape[0]):
        emb.append(
            jax.lax.dynamic_slice_in_dim(
                self.weights[0], state[i], inputs.shape[1], axis=0))
      self.state = state + inputs.shape[1]
      return inputs + jnp.stack(emb, 0)
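
# --- Illustrative sketch (not from the original source) ---
# In predict mode the layer adds one positional vector per call and advances
# its stored position; a minimal standalone version of that bookkeeping:
def _example_predict_step(weights, state, token_emb):
  out = token_emb + weights[0, state, :]
  return out, state + 1  # next call uses the incremented position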
def favor(query, key, value, mask):
  # `relu`, `numerical_stabilizer`, `n_heads` and the numerator/denominator
  # helpers come from the enclosing scope in the original source.
  query_prime = relu(query) + numerical_stabilizer
  key_prime = relu(key) + numerical_stabilizer
  # Broadcast the [batch, length] padding mask across heads, then zero out
  # key features at masked positions.
  mask_batch_1_length = jnp.reshape(
      mask, [key.shape[0] // n_heads, 1, key.shape[1]]).astype(jnp.float32)
  mask_heads = mask_batch_1_length + jnp.zeros((1, n_heads, 1))
  key_prime *= jnp.reshape(mask_heads, [key.shape[0], key.shape[1], 1])

  w = bidirectional_numerator(jnp.moveaxis(query_prime, 1, 0),
                              jnp.moveaxis(key_prime, 1, 0),
                              jnp.moveaxis(value, 1, 0))
  r = bidirectional_denominator(jnp.moveaxis(query_prime, 1, 0),
                                jnp.moveaxis(key_prime, 1, 0))
  w = jnp.moveaxis(w, 0, 1)
  r = jnp.moveaxis(r, 0, 1)
  r = jnp.reciprocal(r)
  r = jnp.expand_dims(r, len(r.shape))
  renormalized_attention = w * r
  return renormalized_attention, mask
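
# --- Illustrative sketch (not from the original source) ---
# `bidirectional_numerator`/`bidirectional_denominator` are defined elsewhere;
# for non-causal FAVOR (Choromanski et al., 2021) they amount to the einsums
# below, given [length, batch, features] inputs after the moveaxis calls:
def _example_bidirectional_numerator(qs, ks, vs):
  kv = jnp.einsum('lbm,lbd->bmd', ks, vs)    # sum over positions of k v^T
  return jnp.einsum('lbm,bmd->lbd', qs, kv)  # each query against that sum

def _example_bidirectional_denominator(qs, ks):
  return jnp.einsum('lbm,bm->lb', qs, jnp.sum(ks, axis=0))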
def favor(query, key, value):
  # `relu`, `numerical_stabilizer`, `precision` and the numerator/denominator
  # helpers come from the enclosing scope in the original source.
  query_prime = relu(query) + numerical_stabilizer
  key_prime = relu(key) + numerical_stabilizer
  # Causal FAVOR: attention is computed with running prefix sums over
  # positions, so initialize the accumulators with zeros.
  prefix_sum_tensor_shape = (key.shape[0], key.shape[-1], value.shape[-1])
  t_slice_shape = (key.shape[0], key.shape[-1])
  init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape)
  init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape)

  w = favor_numerator(init_prefix_sum_value_numerator, precision,
                      jnp.moveaxis(query_prime, 1, 0),
                      jnp.moveaxis(key_prime, 1, 0),
                      jnp.moveaxis(value, 1, 0))
  r = favor_denominator(init_prefix_sum_value_denominator, precision,
                        jnp.moveaxis(query_prime, 1, 0),
                        jnp.moveaxis(key_prime, 1, 0))
  w = jnp.moveaxis(w, 0, 1)
  r = jnp.moveaxis(r, 0, 1)
  r = jnp.reciprocal(r)
  r = jnp.expand_dims(r, len(r.shape))
  renormalized_attention = w * r
  return renormalized_attention
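
# --- Illustrative sketch (not from the original source) ---
# `favor_numerator`/`favor_denominator` are defined elsewhere. A minimal
# causal numerator in the same spirit, assuming a lax.scan-style
# `fastmath.scan` and [length, batch, ...] inputs:
def _example_causal_numerator(init_prefix_sum, qs, ks, vs):
  def body(carry, qkv):
    q, k, v = qkv
    carry = carry + jnp.einsum('bm,bd->bmd', k, v)  # running sum of k v^T
    return carry, jnp.einsum('bm,bmd->bd', q, carry)
  _, ws = fastmath.scan(body, init_prefix_sum, (qs, ks, vs))
  return ws  # the denominator is analogous, with a running sum of k alone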
def f(x):  # pylint: disable=invalid-name
  # Keep only the first position along the length axis, retaining its dim.
  return jnp.expand_dims(x[:, 0, :], 1)
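
# --- Illustrative usage (hypothetical shapes) ---
def _example_f():
  x = jnp.ones((2, 5, 4))  # [batch, length, depth]
  return f(x).shape        # (2, 1, 4): first position, length axis kept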