import jax.numpy as jnp

def loss_fun(x, step):
    # Cross-entropy plus l2 penalty; `predict_fun`, `features`, `targets`,
    # `num_classes`, `l2_pen`, and `utils` come from the enclosing scope.
    # predict_fun is assumed to return log-probabilities already.
    del step  # Step count is unused by this loss.
    logits = jnp.squeeze(predict_fun(x, features))
    onehot_targets = utils.one_hot(targets, num_classes)
    data_loss = -jnp.mean(jnp.sum(logits * onehot_targets, axis=1))
    reg_loss = l2_pen * utils.norm(x)
    return data_loss + reg_loss
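The (x, step) signature matches a step-indexed optimizer callback, even though this particular loss discards the step. As a minimal sketch of how such a loss is consumed, here is plain gradient descent on a hypothetical quadratic stand-in (toy_loss and the 0.1 learning rate are illustrative, not from the source):

import jax
import jax.numpy as jnp

# Hypothetical quadratic loss with the same (params, step) signature.
def toy_loss(x, step):
    del step  # Unused, as in the loss above.
    return jnp.sum(x ** 2)

x = jnp.ones(3)
for step in range(10):
    g = jax.grad(toy_loss)(x, step)  # Differentiate w.r.t. the first argument.
    x = x - 0.1 * g                  # Plain gradient-descent update.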
def test_array_norm():
  """Tests computing the l2 norm of a single array."""

  # Generate a single array to compute the norm of.
  n = 10
  l2_norm = utils.norm(jnp.ones(n))
  assert np.allclose(jnp.sqrt(n), l2_norm)
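utils.norm itself is not shown here, but both tests are consistent with a pytree-aware l2 norm that treats all leaves as one flat vector. A minimal sketch of such a function, under that assumption:

import jax
import jax.numpy as jnp

def norm(pytree):
    """l2 norm of all leaves of a pytree, treated as one flat vector."""
    leaves = jax.tree_util.tree_leaves(pytree)
    return jnp.sqrt(sum(jnp.vdot(leaf, leaf) for leaf in leaves))

assert jnp.allclose(norm(jnp.ones(10)), jnp.sqrt(10.0))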
Example #3
from jax.scipy.special import logsumexp

def loss_fun(x, step):
    # Softmax cross-entropy with l2 regularization; `predict_fun`,
    # `features`, `targets`, `num_classes`, and `l2_pen` come from the
    # enclosing scope.
    del step  # Step count is unused by this loss.
    logits = predict_fun(x, features)
    # Normalize row-wise: shifted logits are log-probabilities.
    logits -= logsumexp(logits, axis=1, keepdims=True)
    onehot_targets = utils.one_hot(targets, num_classes)
    data_loss = -jnp.mean(jnp.sum(logits * onehot_targets, axis=1))
    reg_loss = l2_pen * utils.norm(x)
    return data_loss + reg_loss
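Subtracting the row-wise logsumexp turns raw logits into normalized log-probabilities, which makes the masked mean above exactly softmax cross-entropy. A small self-contained check of that normalization:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jnp.array([[2.0, 1.0, 0.1]])
log_probs = logits - logsumexp(logits, axis=1, keepdims=True)
# exp of the shifted logits is exactly softmax(logits): each row sums to 1.
assert jnp.allclose(jnp.exp(log_probs).sum(axis=1), 1.0)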
def test_dict_norm():
  """Tests the (vectorized) l2 norm of a pytree."""

  # Generate a random pytree (in this case, a dict).
  rs = np.random.RandomState(0)
  pytree = {
      'x': rs.randn(10,),
      'y': rs.randn(3, 5),
  }

  # Test the norm of the vectorized pytree.
  vec = np.hstack((pytree['x'], pytree['y'].ravel()))
  assert np.allclose(np.linalg.norm(vec), float(utils.norm(pytree)), atol=1e-3)
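The reference value here hand-flattens the dict with np.hstack; JAX ships the same flattening as jax.flatten_util.ravel_pytree, which could serve equally well as the reference:

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

pytree = {'x': jnp.arange(3.0), 'y': jnp.ones((2, 2))}
vec, unravel = ravel_pytree(pytree)  # All leaves concatenated into one 1-D array.
l2 = jnp.linalg.norm(vec)            # Matches the hand-rolled np.hstack reference.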
Example #5
def loss_fun(x, step):
    # Binary logistic loss (sigmoid cross-entropy) with l2 regularization;
    # `predict_fun`, `features`, `targets`, and `l2_pen` come from the
    # enclosing scope.
    del step  # Step count is unused by this loss.
    logits = jnp.squeeze(predict_fun(x, features))
    # log(1 + e^z) - y*z is the negative log-likelihood of a Bernoulli
    # model with logit z and label y.
    data_loss = jnp.mean(jnp.log1p(jnp.exp(logits)) - targets * logits)
    reg_loss = l2_pen * utils.norm(x)
    return data_loss + reg_loss
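The data term log(1 + e^z) - y*z is the binary logistic (sigmoid cross-entropy) loss. As written, jnp.log1p(jnp.exp(logits)) can overflow for large positive logits; jnp.logaddexp(0., logits) computes the same quantity stably, as this quick check illustrates:

import jax.numpy as jnp

z = jnp.array([-3.0, 0.0, 30.0])
naive = jnp.log1p(jnp.exp(z))    # Overflows once e^z exceeds the float range.
stable = jnp.logaddexp(0.0, z)   # log(e^0 + e^z) = log(1 + e^z), computed stably.
assert jnp.allclose(naive, stable)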