Example 1
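This appears to be the Adafactor `update` step from the Trax optimizer code: a factored second-moment estimate for matrix-shaped weights, optional update clipping, optional momentum, and decoupled weight decay.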
    def update(self, step, grads, weights, slots, opt_params):
        updates = []
        learning_rate = opt_params['learning_rate']
        beta1 = opt_params['beta1']
        decay_rate = opt_params['decay_rate']
        clipping_threshold = opt_params['clipping_threshold']
        weight_decay_rate = opt_params['weight_decay_rate']
        weight_decay_n_steps = opt_params['weight_decay_n_steps']
        weight_decay_rate = jnp.where(
            weight_decay_n_steps < 1,  # if weight_decay_n_steps == 0, ignore it
            weight_decay_rate,
            (weight_decay_rate *
             jnp.maximum(weight_decay_n_steps - step, 0.0) /
             jnp.maximum(weight_decay_n_steps, 0.0)))
        epsilon1 = opt_params['epsilon1']
        epsilon2 = opt_params['epsilon2']
        # Turn the decay-rate hyperparameter into a step-dependent schedule.
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            update_scale *= jnp.maximum(jnp.sqrt(jnp.mean(weights * weights)),
                                        epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads
        # Factored second-moment estimate: for matrices, keep only per-row and
        # per-column statistics instead of a full elementwise accumulator.
        if self._factored and len(weights.shape) >= 2:
            v_row = slots.pop(0)
            v_col = slots.pop(0)
            new_v_row = (decay_rate * v_row +
                         mixing_rate * jnp.mean(grads_sqr, axis=-1))
            new_v_col = (decay_rate * v_col +
                         mixing_rate * jnp.mean(grads_sqr, axis=-2))
            updates.extend([new_v_row, new_v_col])
            row_mean = jnp.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (row_mean / (new_v_row + epsilon1))**0.5
            col_factor = (new_v_col + epsilon1)**-0.5
            y = (grads * jnp.expand_dims(row_factor, axis=-1) *
                 jnp.expand_dims(col_factor, axis=-2))
        else:
            v = slots.pop(0)
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v + epsilon1)**-0.5

        # Optionally clip the update so its RMS stays below clipping_threshold.
        if self._do_clipping:
            clipping_denom = (jnp.maximum(
                1.0,
                jnp.sqrt(jnp.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        # Optionally smooth the update with a first-moment (momentum) slot.
        if self._do_momentum:
            m = slots.pop(0)
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_weights = (1 - weight_decay_rate) * weights - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_weights.astype(weights.dtype), updates
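A minimal standalone sketch of the factoring trick above, with illustrative names and shapes that are not part of the original class: the squared gradients of a matrix parameter are summarized by row and column means, and the preconditioning factors are rebuilt from those two vectors.

import jax.numpy as jnp

g = jnp.array([[0.1, -0.2, 0.3],
               [0.4, 0.0, -0.1]])
g_sqr = g * g
v_row = jnp.mean(g_sqr, axis=-1)  # shape (2,): one statistic per row
v_col = jnp.mean(g_sqr, axis=-2)  # shape (3,): one statistic per column
# Rank-1 reconstruction of the preconditioner, mirroring row_factor/col_factor:
row_factor = (jnp.mean(v_row) / v_row) ** 0.5
col_factor = v_col ** -0.5
y = g * row_factor[:, None] * col_factor[None, :]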
Example 2
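Log-density of a multivariate Gaussian with diagonal covariance, computed as `-0.5 * (d * log(2*pi) + log|Sigma| + (x - mu)^T Sigma^{-1} (x - mu))`, with the quadratic form evaluated elementwise.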
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
    """Returns `log N(x | mu, diag(diag_sigma))`.

    Args:
      x: point at which to evaluate the log-density.
      mu: mean of the Gaussian.
      diag_sigma: diagonal entries of the covariance matrix.
    """
    a = mu.shape[-1] * jnp.log(2 * jnp.pi)     # d * log(2*pi)
    b = jnp.sum(jnp.log(diag_sigma), axis=-1)  # log|Sigma| for diagonal Sigma
    y = (x - mu) / diag_sigma                  # Sigma^{-1} (x - mu), elementwise
    y = jnp.expand_dims(y, axis=-1)
    xm = jnp.expand_dims(x - mu, axis=-2)
    c = jnp.matmul(xm, y)
    c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
    return -0.5 * (a + b + c)
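A small usage sketch, assuming `jax.numpy` is available as `jnp` as in the snippet; for a diagonal covariance the result matches the sum of independent one-dimensional Gaussian log-densities.

import jax.numpy as jnp

x = jnp.array([0.5, -1.0])
mu = jnp.zeros(2)
diag_sigma = jnp.array([1.0, 4.0])  # variances on the diagonal
lp = log_gaussian_diag_pdf(x, mu, diag_sigma)
# Cross-check against the product of 1-D densities:
lp_ref = jnp.sum(-0.5 * (jnp.log(2 * jnp.pi * diag_sigma)
                         + (x - mu) ** 2 / diag_sigma))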
Example 3
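The full-covariance counterpart: `slogdet` supplies `log|Sigma|` and a linear solve supplies `Sigma^{-1} (x - mu)`, so the covariance matrix is never inverted explicitly.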
def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
    """Returns `log N(x | mu, sigma)`.

    Args:
      x: point at which to evaluate the log-density.
      mu: mean of the Gaussian.
      sigma: full covariance matrix of the Gaussian.
    """
    a = mu.shape[-1] * jnp.log(2 * jnp.pi)  # d * log(2*pi)
    _, b = jnp.linalg.slogdet(sigma)        # log|Sigma|
    y = jnp.linalg.solve(sigma, x - mu)     # Sigma^{-1} (x - mu)
    y = jnp.expand_dims(y, axis=-1)
    xm = jnp.expand_dims(x - mu, axis=-2)
    c = jnp.matmul(xm, y)
    c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
    return -0.5 * (a + b + c)
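A usage sketch under the same assumptions; with a diagonal `sigma` it agrees with `log_gaussian_diag_pdf` above.

import jax.numpy as jnp

x = jnp.array([0.5, -1.0])
mu = jnp.zeros(2)
sigma = jnp.diag(jnp.array([1.0, 4.0]))  # full 2x2 covariance matrix
lp = log_gaussian_pdf(x, mu, sigma)      # equals the diagonal version here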
Example 4
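This appears to be the `forward` of a Trax positional-encoding layer: in training it can add the encoding starting from a random offset, and in 'predict' mode it uses layer state to track the current position for token-by-token decoding.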
    def forward(self, inputs):
        """Returns the input activations, with added positional information."""
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            # In training, optionally start the encoding at a random offset so
            # the model does not rely solely on absolute positions.
            if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
                px = self.weights[:, :symbol_size, :]
            else:
                rng1, rng2 = fastmath.random.split(self.rng, 2)
                start = fastmath.random.randint(rng1, (), 0,
                                                self._max_offset_to_add)
                start_from_zero = fastmath.random.uniform(
                    rng2, (), jnp.float32, 0, 1)
                start = jnp.where(start_from_zero < self._start_from_zero_prob,
                                  jnp.zeros((), dtype=jnp.int32), start)
                px = fastmath.dynamic_slice_in_dim(self.weights,
                                                   start,
                                                   symbol_size,
                                                   axis=1)
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State is only used for fast inference. In that case, the model is
            # called with consecutive elements position-by-position. This layer
            # stores the index of the current position in its state and
            # increments it on each call -- that's how state is used below.
            state = self.state
            if inputs.shape[1] == 1:
                self.state = state + 1
                return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
            else:
                emb = []
                for i in range(inputs.shape[0]):
                    emb.append(
                        fastmath.dynamic_slice_in_dim(self.weights[0],
                                                      state[i],
                                                      inputs.shape[1],
                                                      axis=0))
                self.state = state + inputs.shape[1]
                res = inputs + jnp.stack(emb, 0)
                return res
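The predict path relies on `dynamic_slice_in_dim` to pick a window of positional vectors starting at a traced index. A toy illustration with made-up shapes, separate from the layer:

import jax
import jax.numpy as jnp

weights = jnp.arange(12.0).reshape(6, 2)  # (max_len, d_feature)
window = jax.lax.dynamic_slice_in_dim(weights, 3, 2, axis=0)
# window == weights[3:5]; unlike static slicing, the start may be traced.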
Example 5
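An earlier variant of the same layer: no random training offset, plus a `jax.lax.tie_in` call that old JAX versions needed to keep `keep_prob` inside the traced computation.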
    def forward(self, inputs):
        """Returns the input activations, with added positional information."""
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            px = self.weights[:, :symbol_size, :]
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                if fastmath.is_backend(fastmath.Backend.JAX):
                    # NOTE: jax.lax.tie_in is a legacy API; it became a no-op
                    # with JAX omnistaging and was removed in later releases.
                    keep_prob = jax.lax.tie_in(
                        x, jnp.full((), keep_prob, dtype=x.dtype))
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State is only used for fast inference. In that case, the model is
            # called with consecutive elements position-by-position. This layer
            # stores the index of the current position in its state and
            # increments it on each call -- that's how state is used below.
            state = self.state
            if inputs.shape[1] == 1:
                self.state = state + 1
                return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
            else:
                emb = []
                for i in range(inputs.shape[0]):
                    emb.append(
                        jax.lax.dynamic_slice_in_dim(self.weights[0],
                                                     state[i],
                                                     inputs.shape[1],
                                                     axis=0))
                self.state = state + inputs.shape[1]
                return inputs + jnp.stack(emb, 0)
Example 6
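Bidirectional FAVOR (Performer-style) attention: queries and keys go through a ReLU feature map, padding positions are masked out of the keys, and the result is renormalized by the matching denominator.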
  def favor(query, key, value, mask):
    # `relu`, `numerical_stabilizer`, `n_heads` and the bidirectional_*
    # helpers are closed over from the enclosing layer definition.
    query_prime = relu(query) + numerical_stabilizer
    key_prime = relu(key) + numerical_stabilizer
    mask_batch_1_length = jnp.reshape(
        mask, [key.shape[0] // n_heads, 1, key.shape[1]]).astype(jnp.float32)
    mask_heads = mask_batch_1_length + jnp.zeros((1, n_heads, 1))
    key_prime *= jnp.reshape(mask_heads, [key.shape[0], key.shape[1], 1])

    w = bidirectional_numerator(jnp.moveaxis(query_prime, 1, 0),
                                jnp.moveaxis(key_prime, 1, 0),
                                jnp.moveaxis(value, 1, 0))
    r = bidirectional_denominator(jnp.moveaxis(query_prime, 1, 0),
                                  jnp.moveaxis(key_prime, 1, 0))
    w = jnp.moveaxis(w, 0, 1)
    r = jnp.moveaxis(r, 0, 1)
    r = jnp.reciprocal(r)
    r = jnp.expand_dims(r, len(r.shape))
    renormalized_attention = w * r
    return renormalized_attention, mask
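For the bidirectional case, the scan-based helpers compute ordinary linear attention with a ReLU feature map. A dense einsum sketch with assumed `(batch, length, features)` shapes and an illustrative stabilizer value, ignoring masking and heads:

import jax.numpy as jnp

eps = 1e-6  # plays the role of numerical_stabilizer
q = jnp.ones((2, 5, 4))
k = jnp.ones((2, 5, 4))
v = jnp.ones((2, 5, 8))
qp = jnp.maximum(q, 0.0) + eps  # relu(query) + stabilizer
kp = jnp.maximum(k, 0.0) + eps
num = jnp.einsum('blf,bmf,bmd->bld', qp, kp, v)  # phi(Q) (phi(K)^T V)
den = jnp.einsum('blf,bmf->bl', qp, kp)          # phi(Q) phi(K)^T 1
attention = num / den[..., None]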
Example 7
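The unidirectional (causal) variant: numerator and denominator are accumulated as prefix sums over positions, seeded with zeros, so each query attends only to keys at earlier or equal positions.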
    def favor(query, key, value):
        # `relu`, `numerical_stabilizer`, `precision` and the favor_* scan
        # helpers are closed over from the enclosing layer definition.
        query_prime = relu(query) + numerical_stabilizer
        key_prime = relu(key) + numerical_stabilizer
        prefix_sum_tensor_shape = (key.shape[0], key.shape[-1],
                                   value.shape[-1])
        t_slice_shape = (key.shape[0], key.shape[-1])
        init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape)
        init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape)

        w = favor_numerator(init_prefix_sum_value_numerator, precision,
                            jnp.moveaxis(query_prime, 1, 0),
                            jnp.moveaxis(key_prime, 1, 0),
                            jnp.moveaxis(value, 1, 0))
        r = favor_denominator(init_prefix_sum_value_denominator, precision,
                              jnp.moveaxis(query_prime, 1, 0),
                              jnp.moveaxis(key_prime, 1, 0))
        w = jnp.moveaxis(w, 0, 1)
        r = jnp.moveaxis(r, 0, 1)
        r = jnp.reciprocal(r)
        r = jnp.expand_dims(r, len(r.shape))
        renormalized_attention = w * r
        return renormalized_attention
Example 8
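A small helper that keeps only the first position along the time axis without dropping that axis.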
def f(x):
    # Keep the first position along axis 1, preserving that axis.
    return jnp.expand_dims(x[:, 0, :], 1)
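For instance (shapes are illustrative):

import jax.numpy as jnp

x = jnp.ones((2, 7, 3))
f(x).shape  # (2, 1, 3): first position selected, time axis kept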