Example #1
        def sample_scan(params, tup, x):
            """ Perform single step update of the network """
            (_, (update_W, update_U, update_b),
             (reset_W, reset_U, reset_b),
             (out_W, out_U, out_b),
             (sm_W, sm_b)) = params
            key, inp, logP, hidden = tup

            update_gate = sigmoid(
                np.dot(inp, update_W) + np.dot(hidden, update_U) + update_b)
            reset_gate = sigmoid(
                np.dot(inp, reset_W) + np.dot(hidden, reset_U) + reset_b)
            output_gate = np.tanh(
                np.dot(inp, out_W) +
                np.dot(np.multiply(reset_gate, hidden), out_U) + out_b)
            output = np.multiply(update_gate, hidden) + np.multiply(
                1 - update_gate, output_gate)
            hidden = output
            logits = np.dot(hidden, sm_W) + sm_b

            key, subkey = random.split(key)

            samples = random.categorical(
                subkey, logits, axis=1, shape=None)  # sampling the conditional
            samples = one_hot(
                samples, sm_b.shape[0])  # convert to one hot encoded vector
            log_P_new = np.sum(samples * log_softmax(logits), axis=1)
            log_P_new = log_P_new + logP  # update the value of the logP of the sample

            return (key, samples, log_P_new, output), samples
Example #2
def mixture_logistic_log_pdf(x: Array, prior_logits: Array, means: Array,
                             scales: Array) -> Array:
    """
    Args:
        x (Array): input vector
            (D,)
        prior_logits (Array): prior logits to weight the components
            (D, K)
        means (Array): means per component per feature
            (D, K)
        scales (Array): scales per component per feature
            (D, K)
    Returns:
        log_pdf (Array) : log PDF for the mixture distribution
    """
    x = jnp.expand_dims(x, axis=-1)

    base_dist = Logistic(loc=means, scale=scales)

    # x = (x - means) / scales
    # normalize logit weights to 1, (D,K)->(D,K)
    prior_logits = log_softmax(prior_logits, axis=-1)

    # calculate the log pdf, (D,K)->(D,K)
    log_pdfs = prior_logits + base_dist.log_prob(x)

    # normalize distribution for components, (D,K)->(D,)
    log_pdf = logsumexp(log_pdfs, axis=-1)

    return log_pdf
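
Usage sketch (hedged): the Logistic object above is assumed to be a distribution exposing a log_prob method (e.g. a distrax- or numpyro-style distribution). The stand-in class below implements just enough of that interface to exercise the function, and the sketch assumes mixture_logistic_log_pdf and its own imports (jnp, log_softmax, logsumexp, the Array alias) are already in scope.

import jax.numpy as jnp
from jax.nn import softplus

class Logistic:
    """Minimal stand-in for a logistic distribution with loc/scale and log_prob."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def log_prob(self, x):
        # log pdf of a logistic distribution: -z - log(scale) - 2 * softplus(-z)
        z = (x - self.loc) / self.scale
        return -z - jnp.log(self.scale) - 2.0 * softplus(-z)

D, K = 3, 4
x = jnp.zeros(D)
prior_logits = jnp.ones((D, K))                               # uniform mixture weights
means = jnp.broadcast_to(jnp.linspace(-1.0, 1.0, K), (D, K))
scales = jnp.ones((D, K))
print(mixture_logistic_log_pdf(x, prior_logits, means, scales))  # shape (D,)
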
Example #3
 def apply_fun(params, inputs):
     serial_params, vhead_params, phead_params = params
     serial_out = serial_apply(serial_params, inputs)
     vhead_out = vhead_apply(vhead_params, serial_out)
     phead_out = log_softmax(phead_apply(phead_params, serial_out))
     out = (vhead_out, phead_out)
     return out
Example #4
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
    """Compute cross entropy and entropy for log probs and targets.
    Args:
     logits: [batch, length, num_classes] float array.
     targets: categorical targets [batch, length] int array.
     weights: None or array of shape [batch, length]
     label_smoothing: label smoothing constant, used to determine the on and off values.
    Returns:
      Tuple of scalar loss and batch normalizing factor.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape))
        )

    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_targets = common_utils.onehot(targets, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = -jnp.sum(soft_targets * log_softmax(logits), axis=-1)
    loss = loss - normalizing_constant

    if weights is not None:
        loss = loss * weights
        normalizing_factor = weights.sum()
    else:
        normalizing_factor = np.prod(targets.shape)

    return loss.sum(), normalizing_factor
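
Hedged usage sketch of the loss above, assuming log_softmax is jax.nn.log_softmax, np/jnp are NumPy/jax.numpy, and common_utils is flax.training.common_utils; callers typically divide the summed loss by the returned normalizing factor.

import jax
import jax.numpy as jnp

# Hypothetical shapes: batch=2, length=5, vocab=10.
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 10))
targets = jnp.zeros((2, 5), dtype=jnp.int32)
weights = jnp.ones((2, 5))  # e.g. a padding mask

total_loss, norm = cross_entropy(logits, targets, weights, label_smoothing=0.1)
mean_loss = total_loss / norm  # average per-token loss
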
Example #5
def mixture_logistic_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector
            (D,)
        prior_logits (JaxArray): prior logits to weight the components
            (D, K)
        means (JaxArray): means per component per feature
            (D, K)
        scales (JaxArray): scales per component per feature
            (D, K)
    Returns:
        log_pdf (JaxArray) : log PDF for the mixture distribution
    """
    # n_components = prior_logits.shape[1]
    #

    # add component dimension, (D,)->(D,1)
    # will allow for broadcasting
    x_r = x.reshape(-1, 1)

    # normalize logit weights to 1, (D,K)->(D,K)
    prior_logits = log_softmax(prior_logits, axis=1)

    # calculate the log pdf, (D,K)->(D,K)
    # print(x.shape, prior_logits.shape, )
    log_pdfs = prior_logits + logistic_log_pdf(x_r, means, scales)
    # print("Log PDFS:", log_pdfs.shape)

    # normalize distribution for components, (D,K)->(D,)
    log_pdf = logsumexp(log_pdfs, axis=1)

    return log_pdf
Example #6
def mixture_logistic_cdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector
            (D,)
        prior_logits (JaxArray): prior logits to weight the components
            (D, K)
        means (JaxArray): means per component per feature
            (D, K)
        scales (JaxArray): scales per component per feature
            (D, K)
    Returns:
        cdf (JaxArray) : CDF for the mixture distribution
    """
    # print(prior_logits.shape)
    # n_features, n_components = prior_logits
    x_r = x.reshape(-1, 1)
    #
    # x_r = np.tile(x, (n_features, n_components))
    # print(x.shape, x_r.shape)
    # normalize logit weights to 1, (D,K)->(D,K)
    prior_logits = log_softmax(prior_logits, axis=1)

    # calculate the log pdf, (D,K)->(D,K)
    log_cdfs = prior_logits + logistic_log_cdf(x_r, means, scales)

    # normalize distribution for components, (D,K)->(D,)
    log_cdf = logsumexp(log_cdfs, axis=1)

    return np.exp(log_cdf)
Example #7
def structure_mixture_params(components) -> LogisticMixtureParams:
    unnormalized_weights = components[:, 2]
    probs = list(np.exp(nn.log_softmax(unnormalized_weights)))
    component_params = [
        LogisticParams(component[0], component[1]) for component in components
    ]
    return LogisticMixtureParams(components=component_params, probs=probs)
Example #8
def mixture_gaussian_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector
            (D,)
        prior_logits (JaxArray): prior logits to weight the components
            (D, K)
        means (JaxArray): means per component per feature
            (D, K)
        scales (JaxArray): scales per component per feature
            (D, K)
    Returns:
        log_pdf (JaxArray) : log PDF for the mixture distribution
    """
    # n_components = prior_logits.shape[1]
    #
    # x_r = np.tile(x, (n_components))
    x_r = x.reshape(-1, 1)
    # normalize logit weights to 1, (D,K)->(D,K)
    prior_logits = log_softmax(prior_logits, axis=1)

    # calculate the log pdf, (D,K)->(D,K)
    log_pdfs = prior_logits + jax.scipy.stats.norm.logpdf(x_r, means, scales)

    # normalize distribution for components, (D,K)->(D,)
    log_pdf = logsumexp(log_pdfs, axis=1)

    return log_pdf
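
Sanity-check sketch: with a single component (K = 1) and uniform weights, the mixture should reduce to jax.scipy.stats.norm.logpdf. This assumes the function above and its imports (jax, log_softmax, logsumexp, the JaxArray alias) are in scope.

import jax
import jax.numpy as jnp

D, K = 4, 1
x = jnp.array([0.3, -1.2, 0.0, 2.5])
prior_logits = jnp.zeros((D, K))
means = jnp.zeros((D, K))
scales = jnp.ones((D, K))

lp_mix = mixture_gaussian_log_pdf(x, prior_logits, means, scales)
lp_ref = jax.scipy.stats.norm.logpdf(x, 0.0, 1.0)
print(jnp.allclose(lp_mix, lp_ref, atol=1e-6))  # expected: True
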
Example #9
def main(unused_argv):
  # Load data and preprocess it.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                          permute_train=True)

  # Build the network
  init_fn, f, _ = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(_LEARNING_RATE, 0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // _BATCH_SIZE

  for i, (x, y) in enumerate(datasets.minibatch(
      x_train, y_train, _BATCH_SIZE, _TRAIN_EPOCHS)):

    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(
          epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary(
      'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
Example #10
def poisson_categorical_log_prior(length, rate):
    """Categorical prior populated with log probabilities of Poisson dist."""
    rate = jnp.array(rate, dtype=jnp.float32)
    values = jnp.expand_dims(jnp.arange(1, length + 1, dtype=jnp.float32), 0)
    log_prob_unnormalized = jnp.log(rate) * values - rate - jax.lax.lgamma(
        values + 1)
    # TODO(tkipf): Length-sensitive normalization.
    return nn.log_softmax(log_prob_unnormalized, axis=1)  # Normalize.
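
Quick check (hedged, assuming nn is jax.nn and the function above with its imports is in scope): the returned log probabilities form a normalized categorical prior over lengths 1..length.

import jax.numpy as jnp

log_prior = poisson_categorical_log_prior(length=5, rate=2.0)  # shape (1, 5)
print(jnp.exp(log_prior).sum(axis=1))  # expected: [1.]
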
Example #11
    def apply_fun(params, x, adj, is_training=False, **kwargs):
        rng = kwargs.pop('rng', None)
        rngs = random.split(rng, len(attn_funs))

        for i, layer_fun in enumerate(attn_funs):
            x = layer_fun(params[i], x, adj, rng=rngs[i], is_training=is_training)
        
        return nn.log_softmax(x)
Example #12
def mixture_logpdf_single(datum, components):
    component_scores = []
    unnormalized_weights = np.array([component[2] for component in components])
    weights = nn.log_softmax(unnormalized_weights)
    for component, weight in zip(components, weights):
        loc = component[0]
        scale = np.max([component[1], 0.01])  # Find a better solution?
        component_scores.append(logistic_logpdf(datum, loc, scale) + weight)
    return scipy.special.logsumexp(np.array(component_scores))
Example #13
def predict(params, x):
    # per-example predictions
    activations = x
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = nn.softmax(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return nn.log_softmax(logits)
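
The predict function above works on a single example (1-D x); the usual JAX pattern batches it with vmap. A sketch with hypothetical layer sizes, assuming nn is jax.nn and jnp is jax.numpy as in the snippet.

import jax.numpy as jnp
from jax import random, vmap

def init_params(key, sizes=(784, 128, 10), scale=0.01):
    # Hypothetical initializer producing the [(w, b), ...] layout predict expects.
    params = []
    for m, n in zip(sizes[:-1], sizes[1:]):
        key, wkey, bkey = random.split(key, 3)
        params.append((scale * random.normal(wkey, (n, m)),
                       scale * random.normal(bkey, (n,))))
    return params

params = init_params(random.PRNGKey(0))
batched_predict = vmap(predict, in_axes=(None, 0))  # map over the leading batch axis of x
images = jnp.ones((32, 784))
log_probs = batched_predict(params, images)         # shape (32, 10)
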
Example #14
File: models.py  Project: gcucurull/jax-gcn
    def apply_fun(params, x, adj, is_training=False, **kwargs):
        rng = kwargs.pop('rng', None)
        k1, k2, k3, k4 = random.split(rng, 4)

        x = drop_fun(None, x, is_training=is_training, rng=k1)
        x = gc1_fun(params[0], x, adj, rng=k2)
        x = nn.relu(x)
        x = drop_fun(None, x, is_training=is_training, rng=k3)
        x = gc2_fun(params[1], x, adj, rng=k4)
        x = nn.log_softmax(x)
        return x
Example #15
    def apply_fun(params, first_x, adj, is_training=False, **kwargs):
        rng = kwargs.pop('rng', None)
        k1, k2, k3, k4 = random.split(rng, 4)
        adj_1, adj_5 = adj

        x = gc1_fun(params[0], (first_x, first_x), adj_1, rng=k2, 
            is_training=is_training) 
        x = gc2_fun(params[1], (first_x, x), adj_5, activation=lambda x: x, 
            rng=k4, is_training=is_training)
        x = nn.log_softmax(x)
        return x
Example #16
def kl_sample_gmm(key, q_mean_u, q_logvar_u, gmm_resps_c, gmm_p_mean_cxu,
                  gmm_p_logvar_cxu, varmin):
    """Sample the KL divergence between a gaussian and mixture of gaussians.
        KL(q||p) = E_q[log q(z) - log p(z)]

  In this case q is gaussian, p is gmm, which requires sampling. So we sample
  z ~ q(z) and then compute log q(z) - log p(z).

  Watch the numerics here, varmin needs to be tiny.

  Arguments:
    key: random.PRNGKey for random bits
    q_mean_u: mean from posterior gaussian distribution with dim u
    q_logvar_u: log variance from posterior gaussian distribution with dim u
    gmm_resps_c: np.array with shape c, responsibilities in the GMM, \pi in
      the above formula is softmax(gmm_resps_c).
    gmm_p_mean_cxu: np.array 2D array with shape mixture by dist dim, means in GMM
    gmm_p_logvar_cxu: "", log variances in GMM
    varmin: Minimum variance allowed (numerically useful).

  Returns:
    Single estimate of the KL divergence.
  """

    # Handle case of one gaussian in the mixture with closed form equations.
    if gmm_resps_c.shape[0] == 1:
        return np.sum(
            kl_gauss_gauss(q_mean_u, q_logvar_u, gmm_p_mean_cxu[0, :],
                           gmm_p_logvar_cxu[0, :], varmin))

    # Otherwise sample the KL
    ll = diag_gaussian_log_likelihood
    gmm_ll = gmm_diag_gaussian_log_likelihood
    sample = diag_gaussian_sample
    keys = random.split(key, 2)

    z_u = sample(keys[0], q_mean_u, q_logvar_u, varmin)
    logq_u = ll(z_u, q_mean_u, q_logvar_u, varmin)  # over multigauss dim

    assert varmin <= 1e-15, "Very small or you need to know what you are doing."
    llp_each_gaussian_cxu = gmm_ll(z_u, gmm_p_mean_cxu, gmm_p_logvar_cxu,
                                   varmin)
    log_normed_resps_cx1 = np.expand_dims(log_softmax(gmm_resps_c), axis=1)
    logp_u = logsumexp(llp_each_gaussian_cxu + log_normed_resps_cx1, axis=0)

    kl_estimate = np.sum(logq_u - logp_u, axis=0)
    return kl_estimate
Example #17
File: wavenet.py  Project: tom-bird/jaxnet
def discretized_mix_logistic_loss(theta, y, num_class=256, log_scale_min=-7.):
    """
    Discretized mixture of logistic distributions loss
    :param theta: B x T x 3 * nr_mix
    :param y:  B x T x 1
    """
    theta_shape = theta.shape

    nr_mix = theta_shape[2] // 3

    # unpack parameters
    means = theta[:, :, :nr_mix]
    log_scales = np.maximum(theta[:, :, nr_mix:2 * nr_mix], log_scale_min)
    logit_probs = theta[:, :, nr_mix * 2:nr_mix * 3]

    # B x T x 1 => B x T x nr_mix
    y = np.broadcast_to(y, y.shape[:-1] + (nr_mix, ))

    centered_y = y - means
    inv_stdv = np.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1. / (num_class - 1))
    cdf_plus = sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_class - 1))
    cdf_min = sigmoid(min_in)

    # log probability for edge case of 0 (before scaling):
    log_cdf_plus = plus_in - softplus(plus_in)
    # log probability for edge case of 255 (before scaling):
    log_one_minus_cdf_min = -softplus(min_in)

    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_y

    log_pdf_mid = mid_in - log_scales - 2. * softplus(mid_in)

    log_probs = np.where(
        y < -0.999, log_cdf_plus,
        np.where(
            y > 0.999, log_one_minus_cdf_min,
            np.where(cdf_delta > 1e-5, np.log(np.maximum(cdf_delta, 1e-12)),
                     log_pdf_mid - np.log((num_class - 1) / 2))))

    log_probs = log_probs + log_softmax(logit_probs)
    return -np.sum(logsumexp(log_probs, axis=-1), axis=-1)
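
Shape-only usage sketch (hedged): np here is assumed to be jax.numpy, with sigmoid and softplus from jax.nn, log_softmax from jax.nn, and logsumexp from jax.scipy.special. theta packs means, log scales, and mixture logits along the last axis, and y is assumed to be scaled to [-1, 1], as the edge cases in the loss suggest.

import jax
import jax.numpy as np

B, T, nr_mix = 2, 100, 10
key_theta, key_y = jax.random.split(jax.random.PRNGKey(0))
theta = jax.random.normal(key_theta, (B, T, 3 * nr_mix))
y = jax.random.uniform(key_y, (B, T, 1), minval=-1.0, maxval=1.0)

loss = discretized_mix_logistic_loss(theta, y)  # negative log-likelihood per sequence, shape (B,)
print(loss.shape)
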
Example #18
        def apply_fun_scan(params, tup, inp):
            """ Perform single step update of the network """
            (_, (update_W, update_U, update_b),
             (reset_W, reset_U, reset_b),
             (out_W, out_U, out_b),
             (sm_W, sm_b)) = params
            hidden, logP = tup

            update_gate = sigmoid(
                np.dot(inp, update_W) + np.dot(hidden, update_U) + update_b)
            reset_gate = sigmoid(
                np.dot(inp, reset_W) + np.dot(hidden, reset_U) + reset_b)
            output_gate = np.tanh(
                np.dot(inp, out_W) +
                np.dot(np.multiply(reset_gate, hidden), out_U) + out_b)
            output = np.multiply(update_gate, hidden) + np.multiply(
                1 - update_gate, output_gate)
            hidden = output
            logP = log_softmax(np.dot(hidden, sm_W) + sm_b)

            return (hidden, logP), (hidden, logP)
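
This step function is written for jax.lax.scan: the carry is (hidden, logP) and each inp is one time step of the input sequence. A hedged sketch of how it could be driven, with hypothetical dimensions and randomly initialized parameters in the nested-tuple layout unpacked above (sigmoid and log_softmax are assumed to come from jax.nn, and apply_fun_scan to be in scope).

import functools
import jax
import jax.numpy as np
from jax import lax, random

in_dim, hid_dim, vocab = 8, 16, 8  # hypothetical sizes
keys = random.split(random.PRNGKey(0), 7)
dense = lambda k, m, n: 0.1 * random.normal(k, (m, n))
params = (
    None,  # placeholder for the first (unused here) parameter group
    (dense(keys[0], in_dim, hid_dim), dense(keys[1], hid_dim, hid_dim), np.zeros(hid_dim)),
    (dense(keys[2], in_dim, hid_dim), dense(keys[3], hid_dim, hid_dim), np.zeros(hid_dim)),
    (dense(keys[4], in_dim, hid_dim), dense(keys[5], hid_dim, hid_dim), np.zeros(hid_dim)),
    (dense(keys[6], hid_dim, vocab), np.zeros(vocab)),
)

T = 5
inputs = jax.nn.one_hot(np.arange(T) % in_dim, in_dim)  # (T, in_dim) one-hot sequence
init_carry = (np.zeros(hid_dim), np.zeros(vocab))
step = functools.partial(apply_fun_scan, params)
_, (hiddens, log_probs) = lax.scan(step, init_carry, inputs)
print(hiddens.shape, log_probs.shape)  # (T, hid_dim) (T, vocab)
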
Example #19
 def loss(observations, actions, rewards_to_go):
     logprobs = log_softmax(policy(observations))
     action_logprobs = logprobs[np.arange(logprobs.shape[0]), actions]
     return -np.mean(action_logprobs * rewards_to_go, axis=0)
Example #20
def cross_entropy(model: Model, inputs: Tensor, targets: Tensor) -> float:
    # log softmax
    predicted = nn.log_softmax(model(inputs))

    # negative log likelihood
    return -np.mean(np.sum(targets * predicted, axis=1))
Example #21
def onnx_log_softmax(x, axis=-1):
    return log_softmax(x, axis)
Example #22
def cross_entropy(outputs: DeviceArray, targets: DeviceArray) -> DeviceArray:
    probs = log_softmax(outputs)
    labels = _one_hot(targets, 10)
    loss = -jnp.mean(probs * labels)
    return loss
Example #23
 def from_params(cls, fixed_params, opt_params, traceable=False):
     logps = nn.log_softmax(opt_params)
     return cls(logps, traceable=traceable)
Example #24
 def criterion(logits, targets):
     return -jnp.mean(jnp.sum(log_softmax(logits) * targets, axis=1), axis=0)
Example #25
def accuracy(params, batch, predict_fn):
    logits = predict_fn(params, batch["X"])
    logits = log_softmax(logits)
    return jnp.mean(jnp.argmax(logits, -1) == batch["y"])
Example #26
def softmax_cross_entropy(logits, labels):
    one_hot_labels = one_hot(labels, logits.shape[-1])
    return -jnp.sum(log_softmax(logits) * one_hot_labels, axis=-1)
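
A typical use is to average this per-example loss over a batch and differentiate it; a small sketch, assuming one_hot and log_softmax are jax.nn.one_hot and jax.nn.log_softmax.

import jax
import jax.numpy as jnp

def batch_loss(logits, labels):
    # mean of the per-example softmax cross entropy defined above
    return jnp.mean(softmax_cross_entropy(logits, labels))

logits = jax.random.normal(jax.random.PRNGKey(0), (32, 10))
labels = jnp.zeros(32, dtype=jnp.int32)
loss, dlogits = jax.value_and_grad(batch_loss)(logits, labels)
print(loss, dlogits.shape)  # scalar loss, (32, 10) gradient w.r.t. the logits
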
Example #27
def hmm_viterbi_log(params, obs_seq, length=None):
    '''
    Computes the most probable sequence of hidden states (the Viterbi path) of the
    Hidden Markov Model given the observations, i.e.
    argmax_z P(z[0], ..., z[num_steps - 1] | x[0], ..., x[num_steps - 1])
    It is based on https://github.com/deepmind/distrax/blob/master/distrax/_src/utils/hmm.py

    Parameters
    ----------
    params : HMM
        Hidden Markov Model
    obs_seq: array(seq_len)
        History of observed states
    length: int, optional
        Number of valid (unpadded) time steps in obs_seq; defaults to seq_len
    Returns
    -------
    * array(seq_len)
        The most likely sequence of hidden states
    '''
    seq_len = len(obs_seq)

    if length is None:
        length = seq_len

    trans_dist, obs_dist, init_dist = params.trans_dist, params.obs_dist, params.init_dist

    trans_log_probs = log_softmax(trans_dist.logits)
    init_log_probs = log_softmax(init_dist.logits)

    n_states = obs_dist.batch_shape[0]

    first_log_prob = init_log_probs + obs_dist.log_prob(obs_seq[0])

    if seq_len == 1:
        return jnp.expand_dims(jnp.argmax(first_log_prob), axis=0)

    def viterbi_forward(prev_logp, t):
        obs_logp = obs_dist.log_prob(obs_seq[t])

        logp = jnp.where(
            t <= length,
            prev_logp[..., None] + trans_log_probs + obs_logp[..., None, :],
            -jnp.inf + jnp.zeros_like(trans_log_probs))

        max_logp_given_successor = jnp.where(t <= length, jnp.max(logp,
                                                                  axis=-2),
                                             prev_logp)
        most_likely_given_successor = jnp.where(t <= length,
                                                jnp.argmax(logp, axis=-2), -1)

        return max_logp_given_successor, most_likely_given_successor

    ts = jnp.arange(1, seq_len)
    final_log_prob, most_likely_sources = lax.scan(viterbi_forward,
                                                   first_log_prob, ts)

    most_likely_initial_given_successor = jnp.argmax(trans_log_probs +
                                                     first_log_prob,
                                                     axis=-2)

    most_likely_sources = jnp.concatenate([
        jnp.expand_dims(most_likely_initial_given_successor, axis=0),
        most_likely_sources
    ],
                                          axis=0)

    def viterbi_backward(state, t):
        state = jnp.where(
            t <= length,
            jnp.sum(most_likely_sources[t] * one_hot(state, n_states)).astype(
                jnp.int64), state)
        most_likely = jnp.where(t <= length, state, -1)
        return state, most_likely

    final_state = jnp.argmax(final_log_prob)
    _, most_likely_path = lax.scan(viterbi_backward,
                                   final_state,
                                   ts,
                                   reverse=True)

    final_state = jnp.where(length == seq_len, final_state, -1)

    return jnp.append(most_likely_path, final_state)
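
Hedged usage sketch: params is assumed to bundle three categorical distributions exposing .logits, .log_prob, and .batch_shape (distrax-style); the HMM namedtuple below is a hypothetical container holding a toy two-state model. It assumes distrax is installed and the function above is in scope with its imports (jnp, lax, log_softmax, one_hot).

from collections import namedtuple

import distrax
import jax.numpy as jnp

HMM = namedtuple("HMM", ["trans_dist", "obs_dist", "init_dist"])

init_probs = jnp.array([0.8, 0.2])
trans_mat = jnp.array([[0.9, 0.1],
                       [0.2, 0.8]])
obs_mat = jnp.array([[0.7, 0.2, 0.1],   # emission probabilities for state 0
                     [0.1, 0.2, 0.7]])  # emission probabilities for state 1

params = HMM(
    trans_dist=distrax.Categorical(logits=jnp.log(trans_mat)),
    obs_dist=distrax.Categorical(logits=jnp.log(obs_mat)),
    init_dist=distrax.Categorical(logits=jnp.log(init_probs)),
)

obs_seq = jnp.array([0, 0, 2, 2, 1])
path = hmm_viterbi_log(params, obs_seq)
print(path)  # most likely hidden state at each time step, shape (seq_len,)
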
Example #28
 def from_params(cls, params, traceable=False):
     logps = nn.log_softmax(params)
     return cls(logps, traceable=traceable)
Example #29
File: main.py  Project: justinchiu/rff-mrf
def inner(num_features, qk_dim, S, T, temp_sqrt):
    qk_key = jax.random.PRNGKey(111)
    key1, key2 = jax.random.split(qk_key)
    q = jax.random.normal(key1, (S, qk_dim)) / temp_sqrt
    #q = jax.random.normal(key1, (T, qk_dim)) / temp_sqrt
    k = jax.random.normal(key2, (T, qk_dim)) / temp_sqrt

    log_probs = log_softmax(q @ k.T)
    print(
        f"Total entropy of true dist: {-(jnp.exp(log_probs) * log_probs).sum()}"
    )
    print(
        f"Mean entropy of true dist: {-(jnp.exp(log_probs) * log_probs).sum(-1).mean()}"
    )
    st.write(
        f"Total entropy of true dist: {-(jnp.exp(log_probs) * log_probs).sum()}"
    )
    st.write(
        f"Mean entropy of true dist: {-(jnp.exp(log_probs) * log_probs).sum(-1).mean()}"
    )

    #"""
    # fit exp kernel
    def loss(q, k, dummy_proj, attn):
        logits = q @ k.T
        probs = softmax(logits)
        return fat.kl(attn, probs).mean()

    L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1)))

    key, key1 = jax.random.split(key_train_init)
    print(f"Softmax fit")
    title = f"KL softmax fit (S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt})"
    kl_ = report_train(
        q,
        k,
        lambda x: None,
        L_dL,
        num_features,
        key1,
        train_fn=fat.train,
        sample=False,
        title=title,
    )
    print(f"kl {kl_}")
    #"""
    #
    """
    # fit exp kernel fwd renorm
    def loss(q, k, scale, dummy_proj, attn):
        logits = renorm(q) @ renorm(k).T
        probs = softmax(logits)
        return fat.kl(attn, probs).mean()
    L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))

    key, key1 = jax.random.split(key_train_init)
    print(f"Softmax fwd renorm fit")
    title = f"KL softmax fwd renorm fit (S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt})"
    kl_ = report_train(
        q, k,
        lambda x: None, L_dL, num_features, key1,
        train_fn = fat.train_proj,
        sample=False, title=title,
        post_renorm=False,
    )
    print(f"kl {kl_}")
    """
    """
    # fit exp kernel post renorm
    def loss(q, k, scale, dummy_proj, attn):
        logits = q @ k.T
        probs = softmax(logits)
        return fat.kl(attn, probs).mean()
    L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))

    key, key1 = jax.random.split(key_train_init)
    print(f"Softmax post renorm fit")
    title = f"KL softmax post renorm fit (S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt})"
    kl_ = report_train(
        q, k,
        lambda x: None, L_dL, num_features, key1,
        train_fn = fat.train_proj,
        sample=False, title=title,
        post_renorm=True,
    )
    print(f"kl {kl_}")
    """
    #
    """
    def loss(q, k, scale, dummy_proj, attn):
        logits = renorm(q) @ renorm(k).T
        probs = softmax(3. * logits)
        return fat.kl(attn, probs).mean()
    L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))

    key, key1 = jax.random.split(key_train_init)
    print(f"Softmax temp fwd renorm fit")
    title = f"KL softmax temp fwd renorm fit (S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt})"
    kl_ = report_train(
        q, k,
        lambda x: None, L_dL, num_features, key1,
        train_fn = fat.train_proj,
        sample=False, title=title,
        post_renorm=False,
    )
    print(f"kl {kl_}")
    """
    """
    def loss(q, k, scale, dummy_proj, attn):
        logits = q @ k.T
        probs = softmax(3. * logits)
        return fat.kl(attn, probs).mean()
    L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))

    key, key1 = jax.random.split(key_train_init)
    print(f"Softmax temp post renorm fit")
    title = f"KL softmax temp post renorm fit (S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt})"
    kl_ = report_train(
        q, k,
        lambda x: None, L_dL, num_features, key1,
        train_fn = fat.train_proj,
        sample=False, title=title,
        post_renorm=True,
    )
    print(f"kl {kl_}")
    """

    def proj_fn(shape, key):
        sample_key, norm_key = jax.random.split(key)
        gaussian_sample = fat.random_projection(num_features, qk_dim,
                                                sample_key)
        print(gaussian_sample.shape)
        projection_matrix = fat.get_2d_array(gaussian_sample,
                                             norm_key,
                                             scaling=0)
        return projection_matrix

    proj_fn = functools.partial(proj_fn, (num_features, qk_dim))

    def proj_fn_gaus(shape, key):
        sample_key, norm_key = jax.random.split(key)
        gaussian_sample = fat.random_projection(num_features, qk_dim,
                                                sample_key)
        return gaussian_sample

    proj_fn_gaus = functools.partial(proj_fn_gaus, (num_features, qk_dim))

    def proj_fn_anti(shape, key):
        sample_key, norm_key = jax.random.split(key)
        gaussian_sample = fat.random_projection(num_features // 2, qk_dim,
                                                sample_key)
        projection_matrix = fat.get_2d_array(gaussian_sample,
                                             norm_key,
                                             scaling=0)
        return jnp.concatenate([projection_matrix, -projection_matrix], axis=0)

    proj_fn_anti = functools.partial(proj_fn_anti, (num_features, qk_dim))

    def proj_fn_reg(shape, key):
        sample_key, norm_key = jax.random.split(key)
        gaussian_sample = fat.random_projection(num_features, qk_dim,
                                                sample_key)
        projection_matrix = fat.get_2d_array(gaussian_sample,
                                             norm_key,
                                             scaling=1)
        return projection_matrix

    proj_fn_reg = functools.partial(proj_fn_reg, (num_features, qk_dim))

    def proj_fn_reg_small(shape, key):
        sample_key, norm_key = jax.random.split(key)
        gaussian_sample = fat.random_projection(num_features, qk_dim,
                                                sample_key)
        projection_matrix = fat.get_2d_array(gaussian_sample,
                                             norm_key,
                                             scaling=2)
        return projection_matrix

    proj_fn_reg_small = functools.partial(proj_fn_reg_small,
                                          (num_features, qk_dim))

    #for sample_key in [True, False]:
    for sample_key in [False]:
        #for proj_fn in [proj_fn, proj_fn_reg]:
        #for this_proj_fn in [proj_fn, proj_fn_anti]:
        for this_proj_fn in [proj_fn]:
            print(this_proj_fn)
            """
            def loss(q, k, scale, proj, attn_dist):
                ra, _ = fat.rff_attn(q, k, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))

            key, key1 = jax.random.split(key_train_init)
            print(f"Orth fit (Sample: {sample_key})")
            title = f"KL Orth fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(q, k, this_proj_fn, L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            """

            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm(q, axis=-1)
                kp = renorm(k, axis=-1)
                ra, _ = fat.rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()

            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))

            key, key1 = jax.random.split(key_train_init)
            print(f"Orth projected L2 fit (Sample: {sample_key})")
            title = f"KL Orth projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q,
                k,
                proj_fn,
                L_dL,
                num_features,
                key1,
                fat.train_proj,
                sample_key,
                title,
            )
            print(f"kl {kl_}")
            #"""
            """
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm_stopgrad(q, axis=-1)
                kp = renorm_stopgrad(k, axis=-1)
                ra, _ = fat.rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))

            key, key1 = jax.random.split(key_train_init)
            print(f"Orth projected L2 detach fit (Sample: {sample_key})")
            title = f"KL Orth projected L2 detach fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(q, k, proj_fn, L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            """
            """
            def loss(q, k, scale, proj, attn_dist):
                ra, _ = fat.rff_attn(q, k, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))

            key, key1 = jax.random.split(key_train_init)
            print(f"Orth post projected L2 fit (Sample: {sample_key})")
            title = f"KL orth post projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(q, k, proj_fn, L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
                post_renorm=True,
            )
            print(f"kl {kl_}")
            """

            #"""
            # learn scale
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm(q, axis=-1)
                kp = renorm(k, axis=-1)
                ra, _ = fat.rff_attn(qp, kp, proj)
                return fat.kl(attn_dist, ra).mean()

            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))

            key, key1 = jax.random.split(key_train_init)
            print(f"Orth projected L2 fit proj (Sample: {sample_key})")
            title = f"KL Orth Projected L2 fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q,
                k,
                proj_fn,
                L_dL,
                num_features,
                key1,
                fat.train_proj,
                sample_key,
                title,
            )
            print(f"kl {kl_}")
            #"""
            """
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm(q, axis=-1)
                kp = renorm(k, axis=-1)
                ra, _ = fat.rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))

            key, key1 = jax.random.split(key_train_init)
            print(f"Orth Projected L2 fit scale projfnreg (Sample: {sample_key})")
            title = f"KL Orth Projected L2 fit scale projfnreg (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                proj_fn_reg,
                L_dL,
                num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            """
            """
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm(q, axis=-1)
                kp = renorm(k, axis=-1)
                ra, _ = fat.rff_attn(qp, kp, jax.lax.stop_gradient(scale) * proj)
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))

            key, key1 = jax.random.split(key_train_init)
            print(f"Orth Projected L2 fit proj (Sample: {sample_key})")
            title = f"KL Orth Projected L2 fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                proj_fn,
                #proj_fn_norm,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            """
            """
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn0(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)

            key, key1 = jax.random.split(key_train_init)
            #print(f"Relu0 Projected L2 fit (Sample: {sample_key})")
            #title = f"KL Relu0 Projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            print(f"Relu0 fit (Sample: {sample_key})")
            title = f"KL Relu0 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            """

            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.exp_rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()

            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)

            key, key1 = jax.random.split(key_train_init)
            print(f"Exp fit (Sample: {sample_key})")
            title = f"KL Exp fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q,
                k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL,
                num_features,
                key1,
                fat.train_proj,
                sample_key,
                title,
            )
            print(f"kl {kl_}")

            #"""
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.exp_rff_attn(qp, kp, proj)
                return fat.kl(attn_dist, ra).mean()

            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)

            key, key1 = jax.random.split(key_train_init)
            print(f"Exp fit proj (Sample: {sample_key})")
            title = f"KL Exp fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q,
                k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL,
                num_features,
                key1,
                fat.train_proj,
                sample_key,
                title,
            )
            print(f"kl {kl_}")

            #"""
            #
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()

            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)

            key, key1 = jax.random.split(key_train_init)
            #print(f"Relu Projected L2 fit (Sample: {sample_key})")
            #title = f"KL Relu Projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            print(f"Relu fit (Sample: {sample_key})")
            title = f"KL Relu fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q,
                k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL,
                num_features,
                key1,
                fat.train_proj,
                sample_key,
                title,
            )
            print(f"kl {kl_}")

            #"""
            #
            #
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn(qp, kp, proj)
                return fat.kl(attn_dist, ra).mean()

            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)

            key, key1 = jax.random.split(key_train_init)
            print(f"Relu fit proj (Sample: {sample_key})")
            title = f"KL Relu fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q,
                k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL,
                num_features,
                key1,
                fat.train_proj,
                sample_key,
                title,
            )
            print(f"kl {kl_}")

            #"""
            #
            #
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()

            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)

            key, key1 = jax.random.split(key_train_init)
            #print(f"Relu Projected L2 fit (Sample: {sample_key})")
            #title = f"KL Relu Projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            print(f"Relu fit small (Sample: {sample_key})")
            title = f"KL Relu fit small (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q,
                k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg_small,
                L_dL,
                num_features,
                key1,
                fat.train_proj,
                sample_key,
                title,
            )
            print(f"kl {kl_}")

            #"""
            #
            #
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn(qp, kp, proj)
                return fat.kl(attn_dist, ra).mean()

            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)

            key, key1 = jax.random.split(key_train_init)
            #print(f"Relu Projected L2 fit (Sample: {sample_key})")
            #title = f"KL Relu Projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            print(f"Relu fit small proj (Sample: {sample_key})")
            title = f"KL Relu fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q,
                k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg_small,
                L_dL,
                num_features,
                key1,
                fat.train_proj,
                sample_key,
                title,
            )
            print(f"kl {kl_}")

            #"""
            #
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn(qp, kp, scale * proj)
                return fat.kl(attn_dist, ra).mean()

            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)

            key, key1 = jax.random.split(key_train_init)
            print(f"Relu fit small proj (Sample: {sample_key})")
            title = f"KL Relu fit small proj scale (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q,
                k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg_small,
                L_dL,
                num_features,
                key1,
                fat.train_proj,
                sample_key,
                title,
            )
            print(f"kl {kl_}")
            #"""
            """
            # bad
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm(q)
                kp = renorm(k)
                ra, _ = fat.relu_rff_attn(qp, kp, proj)
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)

            key, key1 = jax.random.split(key_train_init)
            #print(f"Relu Projected L2 fit (Sample: {sample_key})")
            #title = f"KL Relu Projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            print(f"Relu projected fit proj (Sample: {sample_key})")
            title = f"KL Relu projected fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            """
            """
Example #30
 def logp(self, params, obs, act):
     logits = self._net_apply(params, obs)
     all_logps = nn.log_softmax(logits)
     return (hk.one_hot(act, self.act_dim) * all_logps).sum(-1)