def update(self, step, grads, weights, slots, opt_params):
  updates = []
  learning_rate = opt_params['learning_rate']
  beta1 = opt_params['beta1']
  decay_rate = opt_params['decay_rate']
  clipping_threshold = opt_params['clipping_threshold']
  weight_decay_rate = opt_params['weight_decay_rate']
  weight_decay_n_steps = opt_params['weight_decay_n_steps']
  weight_decay_rate = jnp.where(
      weight_decay_n_steps < 1,  # if weight_decay_n_steps == 0, ignore it
      weight_decay_rate,
      (weight_decay_rate * jnp.maximum(weight_decay_n_steps - step, 0.0) /
       jnp.maximum(weight_decay_n_steps, 0.0)))
  epsilon1 = opt_params['epsilon1']
  epsilon2 = opt_params['epsilon2']
  decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
  update_scale = learning_rate
  if self._multiply_by_parameter_scale:
    update_scale *= jnp.maximum(
        jnp.sqrt(jnp.mean(weights * weights)), epsilon2)
  mixing_rate = 1.0 - decay_rate

  grads_sqr = grads * grads
  if self._factored and len(weights.shape) >= 2:
    v_row = slots.pop(0)
    v_col = slots.pop(0)
    new_v_row = (
        decay_rate * v_row + mixing_rate * jnp.mean(grads_sqr, axis=-1))
    new_v_col = (
        decay_rate * v_col + mixing_rate * jnp.mean(grads_sqr, axis=-2))
    updates.extend([new_v_row, new_v_col])
    row_mean = jnp.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (row_mean / (new_v_row + epsilon1))**0.5
    col_factor = (new_v_col + epsilon1)**-0.5
    y = (grads * jnp.expand_dims(row_factor, axis=-1) *
         jnp.expand_dims(col_factor, axis=-2))
  else:
    v = slots.pop(0)
    new_v = decay_rate * v + mixing_rate * grads_sqr
    updates.append(new_v)
    y = grads * (new_v + epsilon1)**-0.5

  if self._do_clipping:
    clipping_denom = (
        jnp.maximum(1.0, jnp.sqrt(jnp.mean(y * y)) / clipping_threshold))
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._do_momentum:
    m = slots.pop(0)
    new_m = beta1 * m + (1.0 - beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_weights = (1 - weight_decay_rate) * weights - subtrahend
  # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
  return new_weights.astype(weights.dtype), updates
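# The factored branch above keeps only per-row and per-column second-moment
# accumulators for a 2-D weight matrix and rebuilds a per-element
# preconditioner from them. A minimal standalone sketch of that reconstruction,
# assuming only `jax.numpy`, a made-up 3x4 gradient, and with the running
# decay omitted for brevity (hypothetical values, not part of the optimizer
# above):
import jax.numpy as jnp

grads = jnp.arange(12, dtype=jnp.float32).reshape(3, 4) / 10.0
epsilon1 = 1e-30
grads_sqr = grads * grads
v_row = jnp.mean(grads_sqr, axis=-1)   # shape (3,): one entry per row
v_col = jnp.mean(grads_sqr, axis=-2)   # shape (4,): one entry per column
row_mean = jnp.mean(v_row, axis=-1, keepdims=True)
row_factor = (row_mean / (v_row + epsilon1)) ** 0.5
col_factor = (v_col + epsilon1) ** -0.5
# Same rescaling as the factored branch: outer product of the two factors.
y = grads * jnp.expand_dims(row_factor, axis=-1) * jnp.expand_dims(col_factor, axis=-2)
print(y.shape)  # (3, 4), same shape as grads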
def learning_rate(step):
  """Step to learning rate function."""
  ret = 1.0
  for name in factors:
    if name == 'constant':
      ret *= constant
    elif name == 'linear_warmup':
      ret *= jnp.minimum(1.0, step / warmup_steps)
    elif name == 'rsqrt_decay':
      ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
    elif name == 'rsqrt_normalized_decay':
      ret *= jnp.sqrt(warmup_steps)
      ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
    elif name == 'decay_every':
      ret *= (decay_factor**(step // steps_per_decay))
    elif name == 'cosine_decay':
      progress = jnp.maximum(
          0.0, (step - warmup_steps) / float(steps_per_cycle))
      ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
    else:
      raise ValueError('Unknown factor %s.' % name)
  # TODO(henrykm): return float(jnp.max(minimum, ret)) would be
  # better but causes TypeError: 'numpy.float64' object cannot
  # be interpreted as an integer
  if ret <= minimum:
    return minimum
  return ret
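# The schedule above multiplies together whichever factors are named in the
# closure variable `factors`. A standalone sketch of the common
# 'constant * linear_warmup * rsqrt_decay' combination, assuming jax.numpy and
# hypothetical settings constant=1e-3, warmup_steps=1000:
import jax.numpy as jnp

def warmup_rsqrt_lr(step, constant=1e-3, warmup_steps=1000):
  ret = 1.0
  ret *= constant                                   # 'constant'
  ret *= jnp.minimum(1.0, step / warmup_steps)      # 'linear_warmup'
  ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))  # 'rsqrt_decay'
  return float(ret)

print(warmup_rsqrt_lr(100.0))   # still warming up: ~1e-3 * 0.1 / sqrt(1000)
print(warmup_rsqrt_lr(4000.0))  # past warmup: ~1e-3 / sqrt(4000)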
def forward(self, inputs):
  """Executes this layer as part of a forward pass through the model.

  Args:
    inputs: Tensor.

  Returns:
    Tensor of same shape and dtype as the input.
  """
  threshold = self.weights
  return jnp.maximum(inputs, threshold)
def HardTanh():
  r"""Returns a layer that computes a linear approximation to `Tanh`.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          -1 & \text{if}\ x \leq -1, \\
          x  & \text{if}\ -1 < x < 1, \\
          1  & \text{otherwise}.
      \end{array} \right.
  """
  return Fn('HardTanh', lambda x: jnp.maximum(-1, jnp.minimum(1, x)))
def HardSigmoid():
  r"""Returns a layer that computes a linear approximation to `Sigmoid`.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0     & \text{if}\ x \leq -1, \\
          x + 1 & \text{if}\ -1 < x < 0, \\
          1     & \text{otherwise}.
      \end{array} \right.
  """
  return Fn('HardSigmoid', lambda x: jnp.maximum(0, jnp.minimum(1, (1 + x))))
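# A quick check of the two clipped approximations above, applied directly to a
# few sample points with jax.numpy (bypassing the Fn layer wrappers):
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(jnp.maximum(-1, jnp.minimum(1, x)))      # HardTanh:    [-1., -0.5, 0., 0.5, 1.]
print(jnp.maximum(0, jnp.minimum(1, 1 + x)))   # HardSigmoid: [ 0.,  0.5, 1., 1.,  1.]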
def TripletLossFn(v1, v2, margin=0.25):
    """Custom Loss function.

    Args:
        v1 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q1.
        v2 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q2.
        margin (float, optional): Desired margin. Defaults to 0.25.

    Returns:
        jax.interpreters.xla.DeviceArray: Triplet Loss.
    """
    ### START CODE HERE (Replace instances of 'None' with your code) ###

    # use fastnp to take the dot product of the two batches (don't forget to transpose the second argument)
    scores = fastnp.dot(v1, fastnp.transpose(v2))  # pairwise cosine sim
    # calculate new batch size
    batch_size = len(scores)
    # use fastnp to grab all positive `diagonal` entries in `scores`
    positive = fastnp.diagonal(scores)  # the positive ones (duplicates)
    # multiply `fastnp.eye(batch_size)` with 2.0 and subtract it out of `scores`
    negative_without_positive = scores - 2.0 * fastnp.eye(batch_size)
    # take the row by row `max` of `negative_without_positive`.
    # Hint: negative_without_positive.max(axis = [?])
    closest_negative = negative_without_positive.max(axis=1)
    # subtract `fastnp.eye(batch_size)` out of 1.0 and do element-wise multiplication with `scores`
    negative_zero_on_duplicate = (1.0 - fastnp.eye(batch_size)) * scores
    # use `fastnp.sum` on `negative_zero_on_duplicate` for `axis=1` and divide it by `(batch_size - 1)`
    mean_negative = fastnp.sum(negative_zero_on_duplicate, axis=1) / (batch_size - 1)
    # compute `fastnp.maximum` among 0.0 and `A`
    # A = subtract `positive` from `margin` and add `closest_negative`
    triplet_loss1 = fastnp.maximum(margin - positive + closest_negative, 0.0)
    # compute `fastnp.maximum` among 0.0 and `B`
    # B = subtract `positive` from `margin` and add `mean_negative`
    triplet_loss2 = fastnp.maximum(margin - positive + mean_negative, 0.0)
    # add the two losses together and take the `fastnp.mean` of it
    triplet_loss = fastnp.mean(triplet_loss1 + triplet_loss2)

    ### END CODE HERE ###

    return triplet_loss
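# Hypothetical usage sketch for the loss above. In the original notebook
# `fastnp` is trax.fastmath.numpy; here jax.numpy is aliased in its place so
# the snippet is self-contained, and the 2x3 batches are made-up L2-normalized
# rows so the dot products behave as cosine similarities:
import jax.numpy as jnp
fastnp = jnp  # assumption: stand-in for trax.fastmath.numpy

v1 = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
v2 = jnp.array([[0.8, 0.6, 0.0], [0.6, 0.8, 0.0]])
# Diagonal scores (0.8, 0.8) are the positives; the off-diagonal scores
# (0.6, 0.6) feed both the closest-negative and mean-negative terms, so each
# hinge term is max(0.25 - 0.8 + 0.6, 0) = 0.05 and the loss comes out to 0.1.
print(TripletLossFn(v1, v2, margin=0.25))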
def learning_rate(step):
  """Step to learning rate function."""
  ret = 1.0
  for name in factors:
    if name == 'constant':
      ret *= constant
    elif name == 'linear_warmup':
      ret *= jnp.minimum(1.0, step / warmup_steps)
    elif name == 'rsqrt_decay':
      ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
    elif name == 'rsqrt_normalized_decay':
      ret *= jnp.sqrt(warmup_steps)
      ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
    elif name == 'decay_every':
      ret *= (decay_factor**(step // steps_per_decay))
    elif name == 'cosine_decay':
      progress = jnp.maximum(
          0.0, (step - warmup_steps) / float(steps_per_cycle))
      ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
    else:
      raise ValueError('Unknown factor %s.' % name)
  return float(ret)
def ParametricRelu(a=1.):
  r"""Returns a layer that computes a ReLU function with the given slope.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0  & \text{if}\ x \leq 0, \\
          ax & \text{otherwise}.
      \end{array} \right.

  Args:
    a: Slope of line for positive inputs.
  """
  return Fn('ParametricRelu',
            lambda x: jnp.maximum(a * x, jnp.zeros_like(x)))
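# A quick check of ParametricRelu, assuming the Trax `Fn` combinator is in
# scope (e.g. from trax.layers.base import Fn) and using a made-up input:
import jax.numpy as jnp

prelu = ParametricRelu(a=0.5)
print(prelu(jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])))  # [0., 0., 0., 0.5, 1.]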
def _category_cross_entropy(  # pylint: disable=invalid-name
    model_output, targets, label_smoothing, cutoff):
  """Computes category cross entropy with label smoothing."""
  n_categories = model_output.shape[-1]
  target_distributions = core.one_hot(targets, n_categories)
  if label_smoothing:
    if label_smoothing < 0. or label_smoothing > 1.:
      raise ValueError(
          f'Arg label_smoothing ({label_smoothing}) must be between 0 and 1.')
    target_distributions *= (1. - label_smoothing)
    target_distributions += label_smoothing / n_categories
  model_log_distributions = core.log_softmax(model_output)
  cross_ent = -jnp.sum(target_distributions * model_log_distributions, axis=-1)
  if cutoff > 0.0:
    return jnp.maximum(cross_ent, cutoff) - cutoff
  else:
    return cross_ent
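# The label-smoothing step above mixes the one-hot target with a uniform
# distribution over categories. A standalone numeric sketch with jax.numpy,
# assuming 4 categories, target class 2, and label_smoothing=0.1 (hypothetical
# values):
import jax.numpy as jnp

n_categories = 4
label_smoothing = 0.1
one_hot = jnp.array([0.0, 0.0, 1.0, 0.0])
smoothed = one_hot * (1.0 - label_smoothing) + label_smoothing / n_categories
print(smoothed)  # [0.025, 0.025, 0.925, 0.025] -- still sums to 1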
def test_forward(self):
  layer = tl.Fn(
      'SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)

  x0 = np.array([1, 2, 3, 4, 5])
  x1 = np.array([10, 20, 30, 40, 50])

  y0, y1 = layer((x0, x1))
  self.assertEqual(y0.tolist(), [11, 22, 33, 44, 55])
  self.assertEqual(y1.tolist(), [10, 20, 30, 40, 50])

  y2, y3 = layer.forward((x0, x1))
  self.assertEqual(y2.tolist(), [11, 22, 33, 44, 55])
  self.assertEqual(y3.tolist(), [10, 20, 30, 40, 50])

  (y4, y5), state = layer.pure_fn((x0, x1), tl.EMPTY_WEIGHTS,
                                  tl.EMPTY_STATE, None)
  self.assertEqual(y4.tolist(), [11, 22, 33, 44, 55])
  self.assertEqual(y5.tolist(), [10, 20, 30, 40, 50])
  self.assertEqual(state, tl.EMPTY_STATE)
def forward(self, xs):
  self._validate_forward_inputs(xs)
  (step, layers_state) = self.state
  # Get N+1 rngs, N for running layers and one extra.
  rngs = _split_rngs(self.rng, self._n_layers + 1)
  rng0, rngs = rngs[0], rngs[1:]
  if not self.sublayers:  # No-op: leave args unchanged.
    self.state = (step + 1, layers_state)
    return xs

  # Prepare the stack and do some safety checks as in the parent class.
  stack = xs
  new_state = []
  n_layers = self._n_layers
  weights = self.weights
  if n_layers != 1 and len(weights) != n_layers:
    raise ValueError('number of weights ({}) not equal to number of layers '
                     '({})'.format(len(weights), n_layers))
  if n_layers != 1 and len(layers_state) != n_layers:
    raise ValueError('length of state ({}) not equal to number of layers '
                     '({})'.format(len(layers_state), n_layers))

  # TODO(chowdhery): try different strategies, also try running not all
  # layers backwards by using fastmath.stop_gradient where needed.

  # Calculate how many layers to run forward.
  if self._mode == 'train':
    # warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after
    w_steps = float(self._skipping_warmup_steps)
    f_step = step.astype(jnp.float32)
    warmup = jnp.maximum(0.0, (w_steps - f_step) / w_steps)
    # low is the minimum number of layers to *not* skip, from n_layers to 0
    low = warmup * float(n_layers)
    # high should be so that (high - n_layers) / high = 1.0 - skip_fraction
    # because (high - n_layers) / high is the probability we're not skipping
    # (after warmup); so high - n_layers = high - high * skip_fraction
    high = float(n_layers) / self._skip_fraction
    # We want the same rng0 on all cores.
    if fastmath.device_count() > 1:
      rng0 = fastmath.psum(rng0, 'batch')
    n_forward_layers = random.uniform(rng0, (), jnp.float32, low, high)
  else:
    n_forward_layers = float(n_layers)

  # Run layers skipping after a certain number.
  cur_layer_idx = 0.0
  for layer, p, s, rng in zip(self.sublayers, weights, layers_state, rngs):
    inputs = _inputs_from_stack(layer, stack)
    def CondF(t):
      return layer.pure_fn(t[0], t[1], t[2], t[3])  # pylint: disable=cell-var-from-loop
    def PassF(t):
      return t[0], t[2]
    outputs, s = fastmath.cond(
        fastmath.lt(cur_layer_idx, n_forward_layers),
        CondF,
        PassF,
        (inputs, p, s, rng)
    )
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_state.append(s)
    cur_layer_idx += 1.0
  self.state = (step + 1, new_state)
  return stack
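# The train-mode branch above draws the number of layers to run from
# U(low, high): `low` shrinks from n_layers to 0 over the warmup, and
# `high` = n_layers / skip_fraction so that, after warmup, a layer is kept
# with probability of roughly 1 - skip_fraction. A plain-Python sketch of
# that arithmetic with made-up numbers (n_layers=12, skip_fraction=0.25,
# skipping_warmup_steps=1000):
n_layers, skip_fraction, w_steps = 12, 0.25, 1000.0

def skip_bounds(step):
  warmup = max(0.0, (w_steps - step) / w_steps)
  low = warmup * n_layers          # minimum number of layers that must run
  high = n_layers / skip_fraction  # (high - n_layers) / high == 1 - skip_fraction
  return low, high

print(skip_bounds(0.0))     # (12.0, 48.0): at step 0 no layer is skipped
print(skip_bounds(2000.0))  # (0.0, 48.0): after warmup, the full skipping range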