Example 1
  def test_linear_no_bias(self):
    """Test hk_util.Linear() without bias term."""
    def model_fun(x):
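      # functools.partial pre-binds axis=0, so the layer's penalty is
      # hk_util.h1_loss(w, axis=0); the last assertion below recomputes
      # exactly that by hand.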
      return hk_util.Linear(
          5,
          use_bias=False,
          w_init=hk.initializers.TruncatedNormal(1.0),
          w_regularizer=functools.partial(hk_util.h1_loss, axis=0),
          name='foo')(x)

    model = hk_util.transform(model_fun)
    x = jnp.array([-0.3, 0.5, 4.0], np.float32)
    params = model.init(jax.random.PRNGKey(0), x)
    y, penalties = model.apply(params, None, x)

    # Check params structure.
    self.assertIn('foo', params)
    self.assertIn('w', params['foo'])
    self.assertNotIn('b', params['foo'])
    self.assertEqual(params['foo']['w'].shape, (3, 5))

    # Check `y` output value.
    expected_y = jnp.dot(x, params['foo']['w'])
    np.testing.assert_allclose(y, expected_y, atol=1e-6)

    # Check penalties.
    expected_penalties = hk_util.h1_loss(params['foo']['w'], axis=0)
    np.testing.assert_allclose(penalties, expected_penalties, atol=1e-6)
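
Note the call signature used throughout these examples: apply(params, rng, inputs) mirrors hk.transform, with rng=None since the model is deterministic, but it returns an (output, penalties) pair rather than the bare output. A minimal pattern for consuming the pair in a training loss (the squared-error task term here is hypothetical, just to show where penalties goes; Example 2 below does this for real):

def loss_fun(params, x, target):
  y, penalties = model.apply(params, None, x)
  return jnp.mean((y - target)**2) + penalties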
Example 2
  def test_regularized_training(self):
    """Test that adding regularization penalty to the training loss works."""
    np.random.seed(0)
    # Set up the problem of recovering w given x and
    #   y = x . w + noise
    # with the a priori assumption that w is sparse. There are fewer examples
    # than dimensions (x is a wide matrix), so the problem is underdetermined
    # without the sparsity assumption.
    num_examples, num_dim = 8, 10
    x = np.random.randn(num_examples, num_dim).astype(np.float32)
    true_w = np.zeros((num_dim, 2), np.float32)
    true_w[[2, 4, 6], 0] = [1.0, 2.0, 3.0]
    true_w[[3, 5], 1] = [4.0, 5.0]
    y = np.dot(x, true_w) + 1e-3 * np.random.randn(num_examples, 2)

    # Get the least squares estimate for w. It isn't very accurate.
    least_squares_w = np.linalg.lstsq(x, y, rcond=None)[0]
    least_squares_w_error = hk_util.l2_loss(least_squares_w - true_w)

    # Get a better estimate by solving the L1 regularized problem
    #  argmin_w ||x . w - y||_2^2 + c ||w||_1.
    w_regularizer = lambda w: 4.0 * hk_util.l1_loss(w)
    def model_fun(batch):
      x = batch['x']
      return hk_util.Linear(2, use_bias=False, w_regularizer=w_regularizer)(x)

    model = hk_util.transform(model_fun)

    def loss_fun(params, batch):
      """Training loss with L1 regularization penalty term."""
      y_predicted, penalties = model.apply(params, None, batch)
      return hk_util.l2_loss(y_predicted - batch['y']) + penalties

    batch = {'x': x, 'y': y}
    params = model.init(jax.random.PRNGKey(0), batch)
    optimizer = optax.chain(  # Gradient descent with decreasing learning rate.
        optax.trace(decay=0.0, nesterov=False),
        optax.scale_by_schedule(lambda i: -0.05 / jnp.sqrt(1 + i)))
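    # Note: trace(decay=0.0) passes gradients through unchanged, and the
    # negative scale turns the update into a descent step, so this chain is
    # plain SGD with learning rate 0.05 / sqrt(1 + step).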
    opt_state = optimizer.init(params)

    @jax.jit
    def train_step(params, opt_state, batch):
      grads = jax.grad(loss_fun)(params, batch)
      updates, opt_state = optimizer.update(grads, opt_state)
      new_params = optax.apply_updates(params, updates)
      return new_params, opt_state

    for _ in range(1000):
      params, opt_state = train_step(params, opt_state, batch)

    l1_w = params['linear']['w']
    l1_w_error = hk_util.l2_loss(l1_w - true_w).item()

    # The L1-regularized estimate is much more accurate.
    self.assertGreater(least_squares_w_error, 4.0)
    self.assertLess(l1_w_error, 1.0)
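
The hand-rolled optimizer chain above is equivalent to SGD with an inverse-square-root learning-rate schedule. With a recent optax release it can be written more directly (a sketch; optax.sgd accepts a schedule as its learning rate):

optimizer = optax.sgd(learning_rate=lambda i: 0.05 / jnp.sqrt(1 + i))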
Example 3
  def test_summarize_model(self):

    def model_fun(x):
      """A model with two submodules."""

      class Alpha(hk.Module):  # Alpha submodule.

        def __call__(self, x):
          return hk.Sequential([
              hk.Conv2D(8, (3, 3)), jax.nn.relu,
              hk.MaxPool((1, 2, 2, 1), (1, 2, 2, 1), 'VALID'),
              hk.Flatten(),
              hk.Linear(3, with_bias=False)
          ])(x)

      class Beta(hk.Module):  # Beta submodule.

        def __call__(self, x):
          return hk.Sequential([hk.Flatten(), hk.Linear(3), jax.nn.relu])(x)

      return hk.Linear(1)(Alpha()(x) + Beta()(x))

    model = hk_util.transform(model_fun)
    x = np.random.randn(1, 12, 15, 1)
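    # Shape flow for Alpha: the SAME-padded conv keeps (12, 15), the 2x2 pool
    # reduces it to (6, 7), so Flatten yields 6 * 7 * 8 = 336 features. Beta
    # flattens the raw input: 12 * 15 * 1 = 180. Both match the summary below.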
    params = model.init(jax.random.PRNGKey(0), x)

    summary = hk_util.summarize_model(params)
    self.assertEqual(
        summary, """
Variable         Shape            #
alpha/conv2_d.b  (8,)             8
alpha/conv2_d.w  (3, 3, 1, 8)    72
alpha/linear.w   (336, 3)      1008
beta/linear.b    (3,)             3
beta/linear.w    (180, 3)       540
linear.b         (1,)             1
linear.w         (3, 1)           3
Total                          1635
""".strip())
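
As a cross-check, the total that summarize_model reports is simply the sum of element counts over the params tree, recomputable with standard JAX tree utilities:

total = sum(int(np.prod(p.shape)) for p in jax.tree_util.tree_leaves(params))
assert total == 1635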
Example 4
  def test_model_workflow(self):
    meta = FooMetadata(hidden_units=[5, 2])
    model = hk_util.transform(functools.partial(foo_model, meta=meta))

    # Get some random param values.
    batch = {'x': jnp.array([[0.5, 1.0, -1.5]])}
    params = model.init(jax.random.PRNGKey(0), batch)

    # Associate params with the model to get a TrainedModel.
    trained_model = hk_util.TrainedModel(model, meta=meta, params=params)

    # Save and load the model.
    filename = '/tmp/hk_util_test/model.pkl'
    trained_model.save(filename)

    recovered = hk_util.TrainedModel.load(filename, foo_model, FooMetadata)

    # Check that meta, params, and model forward function are the same.
    self.assertEqual(recovered.meta, meta)
    self._assert_tree_equal(recovered.params, params)
    y = recovered(None, batch)
    expected_y = model.apply(params, None, batch)
    np.testing.assert_array_equal(y, expected_y)
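
Judging from load's signature, only meta and params appear to be serialized; the model function (foo_model) and metadata class (FooMetadata) are supplied at load time to rebuild the transformed model, which presumably keeps code objects out of the pickle so saved models survive refactors.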
Example 5
def train_model(meta: Metadata,
                dataset: phone_util.Dataset) -> hk_util.TrainedModel:
    """Train the model."""
    model = hk_util.transform(functools.partial(model_fun, meta=meta))

    # Split off a separate validation dataset.
    dataset_val, dataset_train = dataset.split(meta.validation_fraction)

    def generate_batches(dataset: phone_util.Dataset, batch_size: int):
        """Partition into batches. Examples in any partial batch are dropped."""
        x, y = dataset.get_xy_arrays(meta.classes, shuffle=True)

        batch_size = min(batch_size, len(x))
        num_batches = len(x) // batch_size
        batches_x = x[:num_batches * batch_size].reshape(
            num_batches, batch_size, *x.shape[1:])
        batches_y = y[:num_batches * batch_size].reshape(
            num_batches, batch_size)
        return batches_x, batches_y

    train_x, train_y = generate_batches(dataset_train,
                                        batch_size=meta.batch_size)
    t_eval_x, t_eval_y = generate_batches(dataset_train, batch_size=10000)
    t_eval_batch = {'observed': t_eval_x[0], 'label': t_eval_y[0]}
    v_eval_x, v_eval_y = generate_batches(dataset_val, batch_size=10000)
    v_eval_batch = {'observed': v_eval_x[0], 'label': v_eval_y[0]}

    # Initialize network and optimizer.
    seed = np.uint64(random.getrandbits(64))
    params = model.init(jax.random.PRNGKey(seed), {
        'observed': train_x[0],
        'label': train_y[0]
    })
    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(params)

    # Print model summary.
    print(hk_util.summarize_model(params))

    def loss_fun(params, batch):
        """Training loss to optimize."""
        outputs = model.apply(params, None, batch)
        labels = jax.nn.one_hot(batch['label'], len(meta.classes))
        softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(outputs['scores']))
        softmax_xent /= labels.shape[0]
        disperse = embedding_regularizer(outputs['embedded'], batch['label'],
                                         meta)
        return softmax_xent + disperse + outputs['penalties']

    @jax.jit
    def train_step(params, opt_state, batch):
        """Learning update rule."""
        grads = jax.grad(loss_fun)(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state

    @jax.jit
    def accuracy(params, batch):
        """Evaluate classification accuracy."""
        scores = model.apply(params, None, batch)['scores']
        return jnp.mean(jnp.argmax(scores, axis=-1) == batch['label'])

    # Training loop.
    num_steps = len(train_x) * meta.num_epochs
    step_digits = len(str(num_steps))
    step = 0
    for _ in range(meta.num_epochs):
        for batch_x, batch_y in zip(train_x, train_y):
            step += 1
            train_batch = {'observed': batch_x, 'label': batch_y}
            final_step = (step == num_steps)
            if final_step or step % 500 == 0:
                # Periodically evaluate classification accuracy on the train
                # and validation sets.
                train_accuracy = accuracy(params, t_eval_batch)
                val_accuracy = accuracy(params, v_eval_batch)
                train_accuracy, val_accuracy = jax.device_get(
                    (train_accuracy, val_accuracy))
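                # device_get pulls both metrics back to the host in a single
                # transfer before formatting the log line.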
                print(f'[{step:-{step_digits}d}/{num_steps}] train acc = '
                      f'{train_accuracy:.4f}, val acc = {val_accuracy:.4f}')

            params, opt_state = train_step(params, opt_state, train_batch)

    return hk_util.TrainedModel(model, meta=meta, params=params)
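
A typical call site, with hypothetical meta, dataset, and output path just to show the flow (TrainedModel round-trips to disk as in Example 4):

trained = train_model(meta, dataset)
trained.save('/tmp/phone_model.pkl')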