def test_linear_no_bias(self):
  """Test hk_util.Linear() without bias term."""
  def model_fun(x):
    return hk_util.Linear(
        5,
        use_bias=False,
        w_init=hk.initializers.TruncatedNormal(1.0),
        w_regularizer=functools.partial(hk_util.h1_loss, axis=0),
        name='foo')(x)

  model = hk_util.transform(model_fun)
  x = jnp.array([-0.3, 0.5, 4.0], np.float32)
  params = model.init(jax.random.PRNGKey(0), x)
  y, penalties = model.apply(params, None, x)

  # Check params structure.
  self.assertIn('foo', params)
  self.assertIn('w', params['foo'])
  self.assertNotIn('b', params['foo'])
  self.assertEqual(params['foo']['w'].shape, (3, 5))

  # Check `y` output value.
  expected_y = jnp.dot(x, params['foo']['w'])
  np.testing.assert_allclose(y, expected_y, atol=1e-6)

  # Check penalties.
  expected_penalties = hk_util.h1_loss(params['foo']['w'], axis=0)
  np.testing.assert_allclose(penalties, expected_penalties, atol=1e-6)
def test_regularized_training(self):
  """Test that adding regularization penalty to the training loss works."""
  np.random.seed(0)
  # Set up the problem of recovering w given x and
  #   y = x . w + noise
  # with the a priori assumption that w is sparse. There are fewer examples
  # than dimensions (x is a wide matrix), so the problem is underdetermined
  # without the sparsity assumption.
  num_examples, num_dim = 8, 10
  x = np.random.randn(num_examples, num_dim).astype(np.float32)
  true_w = np.zeros((num_dim, 2), np.float32)
  true_w[[2, 4, 6], 0] = [1.0, 2.0, 3.0]
  true_w[[3, 5], 1] = [4.0, 5.0]
  y = np.dot(x, true_w) + 1e-3 * np.random.randn(num_examples, 2)

  # Get the least squares estimate for w. It isn't very accurate.
  least_squares_w = np.linalg.lstsq(x, y, rcond=None)[0]
  least_squares_w_error = hk_util.l2_loss(least_squares_w - true_w)

  # Get a better estimate by solving the L1 regularized problem
  #   argmin_w ||x . w - y||_2^2 + c ||w||_1.
  w_regularizer = lambda w: 4.0 * hk_util.l1_loss(w)
  def model_fun(batch):
    x = batch['x']
    return hk_util.Linear(2, use_bias=False, w_regularizer=w_regularizer)(x)

  model = hk_util.transform(model_fun)

  def loss_fun(params, batch):
    """Training loss with L1 regularization penalty term."""
    y_predicted, penalties = model.apply(params, None, batch)
    return hk_util.l2_loss(y_predicted - batch['y']) + penalties

  batch = {'x': x, 'y': y}
  params = model.init(jax.random.PRNGKey(0), batch)
  optimizer = optax.chain(  # Gradient descent with decreasing learning rate.
      optax.trace(decay=0.0, nesterov=False),
      optax.scale_by_schedule(lambda i: -0.05 / jnp.sqrt(1 + i)))
  opt_state = optimizer.init(params)

  @jax.jit
  def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fun)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state

  for _ in range(1000):
    params, opt_state = train_step(params, opt_state, batch)

  l1_w = params['linear']['w']
  l1_w_error = hk_util.l2_loss(l1_w - true_w).item()

  # The L1-regularized estimate is much more accurate.
  self.assertGreater(least_squares_w_error, 4.0)
  self.assertLess(l1_w_error, 1.0)
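# A possible companion check (a sketch, not part of the original test): the L1
# penalty should drive most entries of the estimate toward zero, matching the
# 5 nonzero entries of `true_w`, whereas the least squares estimate is dense.
# The 0.1 threshold is an arbitrary assumption.
def count_near_zero(w, threshold=0.1):
  """Counts entries of `w` with magnitude below `threshold`."""
  return int(np.sum(np.abs(w) < threshold))

# Expected usage: count_near_zero(l1_w) should be close to
# true_w.size - 5 = 15, while count_near_zero(least_squares_w) is near 0.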
def test_summarize_model(self):
  def model_fun(x):
    """A model with two submodules."""

    class Alpha(hk.Module):  # Alpha submodule.

      def __call__(self, x):
        return hk.Sequential([
            hk.Conv2D(8, (3, 3)),
            jax.nn.relu,
            hk.MaxPool((1, 2, 2, 1), (1, 2, 2, 1), 'VALID'),
            hk.Flatten(),
            hk.Linear(3, with_bias=False)
        ])(x)

    class Beta(hk.Module):  # Beta submodule.

      def __call__(self, x):
        return hk.Sequential([hk.Flatten(), hk.Linear(3), jax.nn.relu])(x)

    return hk.Linear(1)(Alpha()(x) + Beta()(x))

  model = hk_util.transform(model_fun)
  x = np.random.randn(1, 12, 15, 1)
  params = model.init(jax.random.PRNGKey(0), x)
  summary = hk_util.summarize_model(params)
  self.assertEqual(summary, """
Variable         Shape           #
alpha/conv2_d.b  (8,)            8
alpha/conv2_d.w  (3, 3, 1, 8)   72
alpha/linear.w   (336, 3)     1008
beta/linear.b    (3,)            3
beta/linear.w    (180, 3)      540
linear.b         (1,)            1
linear.w         (3, 1)          3
Total                         1635
""".strip())
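# For reference, the Total row of the summary can be reproduced directly from
# the params pytree (a sketch, assuming `params` is a nested dict of
# module name -> variable name -> array, as exercised in the tests above):
def count_params(params) -> int:
  """Sums the number of elements over all arrays in a params pytree."""
  return sum(int(np.prod(w.shape)) for w in jax.tree_util.tree_leaves(params))

# For the model above, count_params(params) == 1635, matching the summary.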
def test_model_workflow(self):
  meta = FooMetadata(hidden_units=[5, 2])
  model = hk_util.transform(functools.partial(foo_model, meta=meta))

  # Get some random param values.
  batch = {'x': jnp.array([[0.5, 1.0, -1.5]])}
  params = model.init(jax.random.PRNGKey(0), batch)

  # Associate params with the model to get a TrainedModel.
  trained_model = hk_util.TrainedModel(model, meta=meta, params=params)

  # Save and load the model.
  filename = '/tmp/hk_util_test/model.pkl'
  trained_model.save(filename)
  recovered = hk_util.TrainedModel.load(filename, foo_model, FooMetadata)

  # Check that meta, params, and model forward function are the same.
  self.assertEqual(recovered.meta, meta)
  self._assert_tree_equal(recovered.params, params)
  y = recovered(None, batch)
  expected_y = model.apply(params, None, batch)
  np.testing.assert_array_equal(y, expected_y)
def train_model(meta: Metadata,
                dataset: phone_util.Dataset) -> hk_util.TrainedModel:
  """Train the model."""
  model = hk_util.transform(functools.partial(model_fun, meta=meta))

  # Split off a separate validation dataset.
  dataset_val, dataset_train = dataset.split(meta.validation_fraction)

  def generate_batches(dataset: phone_util.Dataset, batch_size: int):
    """Partition into batches. Examples in any partial batch are dropped."""
    x, y = dataset.get_xy_arrays(meta.classes, shuffle=True)
    batch_size = min(batch_size, len(x))
    num_batches = len(x) // batch_size
    batches_x = x[:num_batches * batch_size].reshape(num_batches, batch_size,
                                                     *x.shape[1:])
    batches_y = y[:num_batches * batch_size].reshape(num_batches, batch_size)
    return batches_x, batches_y

  train_x, train_y = generate_batches(dataset_train,
                                      batch_size=meta.batch_size)
  t_eval_x, t_eval_y = generate_batches(dataset_train, batch_size=10000)
  t_eval_batch = {'observed': t_eval_x[0], 'label': t_eval_y[0]}
  v_eval_x, v_eval_y = generate_batches(dataset_val, batch_size=10000)
  v_eval_batch = {'observed': v_eval_x[0], 'label': v_eval_y[0]}

  # Initialize network and optimizer.
  seed = np.uint64(random.getrandbits(64))
  params = model.init(jax.random.PRNGKey(seed),
                      {'observed': train_x[0], 'label': train_y[0]})
  optimizer = optax.adam(1e-3)
  opt_state = optimizer.init(params)

  # Print model summary.
  print(hk_util.summarize_model(params))

  def loss_fun(params, batch):
    """Training loss to optimize."""
    outputs = model.apply(params, None, batch)
    labels = hk.one_hot(batch['label'], len(meta.classes))
    softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(outputs['scores']))
    softmax_xent /= labels.shape[0]
    disperse = embedding_regularizer(outputs['embedded'], batch['label'], meta)
    return softmax_xent + disperse + outputs['penalties']

  @jax.jit
  def train_step(params, opt_state, batch):
    """Learning update rule."""
    grads = jax.grad(loss_fun)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state

  @jax.jit
  def accuracy(params, batch):
    """Evaluate classification accuracy."""
    scores = model.apply(params, None, batch)['scores']
    return jnp.mean(jnp.argmax(scores, axis=-1) == batch['label'])

  # Training loop.
  num_steps = len(train_x) * meta.num_epochs
  step_digits = len(str(num_steps))
  step = 0
  for _ in range(meta.num_epochs):
    for batch_x, batch_y in zip(train_x, train_y):
      step += 1
      train_batch = {'observed': batch_x, 'label': batch_y}
      final_step = (step == num_steps)
      if final_step or step % 500 == 0:
        # Periodically evaluate classification accuracy on the train and
        # validation sets.
        train_accuracy = accuracy(params, t_eval_batch)
        val_accuracy = accuracy(params, v_eval_batch)
        train_accuracy, val_accuracy = jax.device_get(
            (train_accuracy, val_accuracy))
        print(f'[{step:-{step_digits}d}/{num_steps}] train acc = '
              f'{train_accuracy:.4f}, val acc = {val_accuracy:.4f}')

      params, opt_state = train_step(params, opt_state, train_batch)

  return hk_util.TrainedModel(model, meta=meta, params=params)
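# Hypothetical usage of train_model (a sketch; the Metadata field values,
# dataset construction, and file path below are assumptions, not shown in
# this module). It mirrors the TrainedModel save/load/call API exercised in
# test_model_workflow:
#
#   trained_model = train_model(meta, dataset)
#   trained_model.save('/tmp/phone_model.pkl')
#   recovered = hk_util.TrainedModel.load(
#       '/tmp/phone_model.pkl', model_fun, Metadata)
#   scores = recovered(None, eval_batch)['scores']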