def update(self, step, grads, params, slots, opt_params):
  """Computes one optimizer step; returns (new_params, updated_slots)."""
  updates = []
  learning_rate = opt_params["learning_rate"]
  beta1 = opt_params["beta1"]
  decay_rate = opt_params["decay_rate"]
  clipping_threshold = opt_params["clipping_threshold"]
  weight_decay_rate = opt_params["weight_decay_rate"]
  epsilon1 = opt_params["epsilon1"]
  epsilon2 = opt_params["epsilon2"]
  decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
  update_scale = learning_rate
  if self._multiply_by_parameter_scale:
    update_scale *= np.maximum(np.sqrt(np.mean(params * params)), epsilon2)
  mixing_rate = 1.0 - decay_rate

  grads_sqr = grads * grads + epsilon1
  if self._factored and len(params.shape) >= 2:
    # Factored second-moment estimate: keep only per-row and per-column
    # running means instead of the full accumulator matrix.
    v_row = slots.pop(0)
    v_col = slots.pop(0)
    new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr, axis=-1)
    new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr, axis=-2)
    updates.extend([new_v_row, new_v_col])
    row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (new_v_row / row_col_mean)**-0.5
    col_factor = (new_v_col)**-0.5
    y = (grads * np.expand_dims(row_factor, axis=-1) *
         np.expand_dims(col_factor, axis=-2))
  else:
    # Unfactored case: a full second-moment accumulator per parameter.
    v = slots.pop(0)
    new_v = decay_rate * v + mixing_rate * grads_sqr
    updates.append(new_v)
    y = grads * (new_v)**-0.5

  if self._do_clipping:
    clipping_denom = (np.maximum(
        1.0, np.sqrt(np.mean(y * y)) / clipping_threshold))
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._do_momentum:
    m = slots.pop(0)
    new_m = beta1 * m + (1.0 - beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_params = (1 - weight_decay_rate) * params - subtrahend
  # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
  return new_params.astype(params.dtype), updates
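# A minimal standalone sketch of the factored second-moment trick used in
# `update` above, written against plain NumPy (`np` in the method is the
# backend's jax.numpy-style alias). The helper name `demo_factored_rms` is
# hypothetical; the point is that per-row and per-column running means
# approximate the full [rows, cols] accumulator with O(rows + cols) memory.
import numpy as onp

def demo_factored_rms(grads, v_row, v_col, decay_rate=0.8, epsilon1=1e-30):
  mixing_rate = 1.0 - decay_rate
  grads_sqr = grads * grads + epsilon1
  new_v_row = decay_rate * v_row + mixing_rate * onp.mean(grads_sqr, axis=-1)
  new_v_col = decay_rate * v_col + mixing_rate * onp.mean(grads_sqr, axis=-2)
  # The outer product of the two factors, normalized by the row mean,
  # reconstructs a rank-1 approximation of the second-moment matrix.
  row_col_mean = onp.mean(new_v_row, axis=-1, keepdims=True)
  row_factor = (new_v_row / row_col_mean)**-0.5
  col_factor = new_v_col**-0.5
  return grads * row_factor[..., None] * col_factor[..., None, :]

g = onp.ones((4, 3))
print(demo_factored_rms(g, onp.zeros(4), onp.zeros(3)).shape)  # (4, 3)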
def Softmax5Branches(x_list, n_branches=5, **unused_kwargs):
  """Softmax-weighted sum of embeddings.

  The input x_list interleaves weights and embeddings in the form
  w_1 e_1 ... w_n e_n (any trailing elements are preserved).

  Args:
    x_list: the input weights and embeddings.
    n_branches: how many (weight, embedding) pairs of the list to use.

  Returns:
    softmax(w)-weighted sum of the embeddings e.
  """
  assert n_branches == 5
  softmax_activations = [x_list[2 * i] for i in range(n_branches)]
  # Numerically stable softmax over the weights: subtract the running max.
  max_sa = softmax_activations[0]
  for x in softmax_activations:
    max_sa = np.maximum(max_sa, x)
  softmax_activations = [x - max_sa for x in softmax_activations]
  softmax_activations = [np.exp(x) for x in softmax_activations]
  sum_sa = sum(softmax_activations)
  softmax_activations = [x / sum_sa for x in softmax_activations]
  res = sum([x_list[2 * i + 1] * softmax_activations[i]
             for i in range(n_branches)])
  return res
def learning_rate(step):  # pylint: disable=invalid-name
  """Step to learning rate function."""
  ret = 1.0
  for name in factors:
    if name == 'constant':
      ret *= constant
    elif name == 'linear_warmup':
      ret *= np.minimum(1.0, step / warmup_steps)
    elif name == 'rsqrt_decay':
      ret /= np.sqrt(np.maximum(step, warmup_steps))
    elif name == 'rsqrt_normalized_decay':
      ret *= np.sqrt(warmup_steps)
      ret /= np.sqrt(np.maximum(step, warmup_steps))
    elif name == 'decay_every':
      ret *= (decay_factor**(step // steps_per_decay))
    elif name == 'cosine_decay':
      progress = np.maximum(0.0,
                            (step - warmup_steps) / float(steps_per_cycle))
      ret *= np.maximum(0.0,
                        0.5 * (1.0 + np.cos(np.pi * (progress % 1.0))))
    else:
      raise ValueError('Unknown factor %s.' % name)
  ret = np.asarray(ret, dtype=np.float32)
  return {'learning_rate': ret}
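# Hedged usage sketch: `learning_rate` above closes over `factors` and the
# schedule hyperparameters. The factory below is an assumption about how
# such a schedule is wired up (`make_schedule` is not the repo's actual
# constructor), shown here to make the factor composition runnable.
import numpy as np

def make_schedule(factors='constant * linear_warmup * rsqrt_decay',
                  constant=0.1, warmup_steps=400):
  factor_list = [f.strip() for f in factors.split('*')]

  def learning_rate(step):  # pylint: disable=invalid-name
    ret = 1.0
    for name in factor_list:
      if name == 'constant':
        ret *= constant
      elif name == 'linear_warmup':
        ret *= np.minimum(1.0, step / warmup_steps)
      elif name == 'rsqrt_decay':
        ret /= np.sqrt(np.maximum(step, warmup_steps))
      else:
        raise ValueError('Unknown factor %s.' % name)
    return np.asarray(ret, dtype=np.float32)

  return learning_rate

schedule = make_schedule()
print(schedule(1), schedule(400), schedule(10000))  # warmup -> peak -> decay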
def learning_rate(step):  # pylint: disable=invalid-name
  """Step to learning rate function."""
  ret = 1.0
  for name in factors:
    if name == "constant":
      ret *= constant
    elif name == "linear_warmup":
      ret *= np.minimum(1.0, step / warmup_steps)
    elif name == "rsqrt_decay":
      ret /= np.sqrt(np.maximum(step, warmup_steps))
    elif name == "decay_every":
      ret *= (decay_factor**(step // steps_per_decay))
    else:
      raise ValueError("Unknown factor %s." % name)
  ret = np.asarray(ret, dtype=np.float32)
  return {"learning_rate": ret}
def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
  """Forward pass for a subset of the query vectors."""
  if self._share_qk:
    key = self.make_unit_length(key)

  dots = np.matmul(
      query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

  # Causal masking
  mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  if self._share_qk:
    self_mask = make_self_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
    dots = dots - 1e5 * self_mask

  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

  if self.dropout is not None and self.dropout > 0.0:
    # Dropout is broadcast across the batch+head dimension
    dropout_shape = (1, dots.shape[-2], dots.shape[-1])
    slice_rng = jax.random.fold_in(rng, q_loop_idx)
    keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
    keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  if self._hard_k > 0:
    top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
    top_k = jax.lax.stop_gradient(top_k)
    dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
    dots = np.maximum(dots, 0)
    dots_sum = np.sum(dots, axis=-1, keepdims=True)  # Re-normalize.
    dots /= dots_sum  # Re-normalize.

  out_slice = np.matmul(dots, value)
  return out_slice
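# A standalone sketch of the `_hard_k` branch above: keep only the largest
# attention weights per query and renormalize. Plain NumPy with a
# hypothetical helper name; the real code runs under jax and wraps the
# threshold in stop_gradient.
import numpy as onp

def demo_hard_k(dots, k):
  top_k = onp.sort(dots)[..., -k]                 # k-th largest per row.
  dots = onp.maximum(dots - top_k[..., None], 0)  # Zero weights <= top_k.
  return dots / onp.sum(dots, axis=-1, keepdims=True)  # Re-normalize.

row = onp.array([[0.5, 0.3, 0.15, 0.05]])
print(demo_hard_k(row, 3))  # [[0.7 0.3 0.  0. ]]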
def Softmax5Branches(x_list, **unused_kwargs):
  """Softmax-weighted average of queries.

  The input x_list is a list of weights and embedded queries of the form
  w_1 ... w_n q_1 ... q_n. The queries q_1 ... q_n are kept by the caller;
  the weighted result is appended.

  Args:
    x_list: the input weights and embedded queries.

  Returns:
    the weighted average of q_1 ... q_n according to softmax(w).
  """
  n_branches = 5
  softmax_activations = x_list[:n_branches]
  # Numerically stable softmax over the weights: subtract the running max.
  max_sa = softmax_activations[0]
  for x in softmax_activations:
    max_sa = np.maximum(max_sa, x)
  softmax_activations = [x - max_sa for x in softmax_activations]
  softmax_activations = [np.exp(x) for x in softmax_activations]
  sum_sa = sum(softmax_activations)
  softmax_activations = [x / sum_sa for x in softmax_activations]
  res = sum([x_list[i + n_branches] * softmax_activations[i]
             for i in range(n_branches)])
  return res
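# Hedged numeric check of the branch mixing above, assuming `np` is plain
# NumPy here. Five scalar weights followed by five (2, 3) queries: the
# output is their softmax(w)-weighted average, so with w_0 largest it leans
# toward queries[0].
import numpy as np

weights = [np.array(w) for w in (2.0, 0.0, 0.0, 0.0, 0.0)]
queries = [np.full((2, 3), float(i)) for i in range(5)]
out = Softmax5Branches(weights + queries)
print(out.shape)  # (2, 3)
print(out[0, 0])  # ~0.878: 0.649 * 0 + 0.088 * (1 + 2 + 3 + 4)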
def forward(self, inputs, weights):
  # ReLU with a learned threshold: elementwise max of inputs and weights[0].
  threshold = weights[0]
  return np.maximum(inputs, threshold)
def HardTanh(x, **unused_kwargs):
  """Linear approximation to tanh."""
  return np.maximum(-1, np.minimum(1, x))
def HardSigmoid(x, **unused_kwargs):
  """Linear approximation to sigmoid."""
  return np.maximum(0, np.minimum(1, (1 + x)))
def ParametricRelu(x, a=1., **unused_kwargs):
  """ReLU on a scaled input: elementwise max of a * x and 0."""
  return np.maximum(a * x, np.zeros_like(x))
def Relu(x, **unused_kwargs):
  """Rectified linear unit: elementwise max of x and 0."""
  return np.maximum(x, np.zeros_like(x))
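# Quick sanity check of the piecewise-linear activations above, assuming a
# NumPy-compatible `np`; the **unused_kwargs slots exist because these
# functions are wrapped as layers elsewhere.
import numpy as np

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(Relu(x))                   # [0.   0.   0.   0.5  2. ]
print(HardTanh(x))               # [-1.  -0.5  0.   0.5  1. ]
print(HardSigmoid(x))            # [0.   0.5  1.   1.   1. ]
print(ParametricRelu(x, a=0.5))  # [0.   0.   0.   0.25 1. ]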