def test_evaluate_model(self):
  # Mock out Model.
  model = models.Model(
      init=lambda rng: None,  # Unused.
      apply_for_train=lambda params, batch, rng: None,  # Unused.
      apply_for_eval=lambda params, batch: batch.get('pred'),
      train_loss=lambda batch, preds: None,  # Unused.
      eval_metrics={
          'accuracy': metrics.Accuracy(),
          'loss': metrics.CrossEntropyLoss(),
      })
  params = jnp.array(3.14)  # Unused.
  batches = [
      {
          'y': np.array([1, 0, 1]),
          'pred': np.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
      },
      {
          'y': np.array([0, 1, 1]),
          'pred': np.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
      },
      {
          'y': np.array([0, 1]),
          'pred': np.array([[1.2, 0.4], [2.3, 0.1]])
      },
  ]
  eval_results = models.evaluate_model(model, params, batches)
  self.assertEqual(eval_results['accuracy'], 0.625)  # 5 / 8.
  self.assertAlmostEqual(eval_results['loss'], 0.8419596, places=6)
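# Worked check for the asserts above, assuming metrics.Accuracy scores
# argmax(pred) against 'y': every pred row argmaxes to class 0 except
# [0.3, 3.2], so the per-batch predictions are [0, 0, 1], [0, 0, 1] and
# [0, 0], scoring 2/3, 2/3 and 1/2 correct, i.e. 5 / 8 = 0.625 overall.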
def test_evaluate_per_client_params(self):
  # Mock out Model.
  model = models.Model(
      init=lambda rng: None,  # Unused.
      apply_for_train=lambda params, batch, rng: None,  # Unused.
      apply_for_eval=lambda params, batch: batch.get('pred') + params,
      train_loss=lambda batch, preds: None,  # Unused.
      eval_metrics={
          'accuracy': metrics.Accuracy(),
          'loss': metrics.CrossEntropyLoss(),
      })
  clients = [
      (b'0000', [{
          'y': np.array([1, 0, 1]),
          'pred': np.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
      }, {
          'y': np.array([0, 1, 1]),
          'pred': np.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
      }], jnp.array([1, -1])),
      (b'1001', [{
          'y': np.array([0, 1]),
          'pred': np.array([[1.2, 0.4], [2.3, 0.1]])
      }], jnp.array([0, 0])),
  ]
  eval_results = dict(
      models.ModelEvaluator(model).evaluate_per_client_params(clients))
  self.assertCountEqual(eval_results, [b'0000', b'1001'])
  self.assertCountEqual(eval_results[b'0000'], ['accuracy', 'loss'])
  npt.assert_allclose(eval_results[b'0000']['accuracy'], 4 / 6)
  npt.assert_allclose(eval_results[b'0000']['loss'], 0.67658216)
  self.assertCountEqual(eval_results[b'1001'], ['accuracy', 'loss'])
  npt.assert_allclose(eval_results[b'1001']['accuracy'], 1 / 2)
  npt.assert_allclose(eval_results[b'1001']['loss'], 1.338092)
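# Sketch of the expected accuracies above: apply_for_eval adds each client's
# params to its logits. Client b'0000' shifts by [1, -1], recovering the
# logits from test_evaluate_model ([0.2 + 1, 1.4 - 1] = [1.2, 0.4], etc.), so
# 2 + 2 of its 6 examples are classified correctly. Client b'1001' adds
# [0, 0], leaving logits that argmax to [0, 0] against y = [0, 1], i.e. 1 / 2.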
def fake_model():

  def apply_for_eval(params, example):
    del params
    return jax.nn.one_hot(example['x'] % 3, 3)

  eval_metrics = {'accuracy': metrics.Accuracy()}
  return models.Model(
      init=None,
      apply_for_train=None,
      apply_for_eval=apply_for_eval,
      train_loss=None,
      eval_metrics=eval_metrics)
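# fake_model deterministically "predicts" class x % 3 as a one-hot over 3
# classes. A minimal usage sketch (the batch here is hypothetical, not taken
# from any test above):
#   preds = fake_model().apply_for_eval(None, {'x': jnp.array([0, 4])})
#   # -> [[1., 0., 0.], [0., 1., 0.]]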
def test_model_per_example_loss(self):
  # Mock out Model.
  model = models.Model(
      init=lambda rng: None,  # Unused.
      apply_for_train=lambda params, batch, rng: batch['x'] * params + rng,
      apply_for_eval=lambda params, batch: None,  # Unused.
      train_loss=lambda batch, preds: jnp.abs(batch['y'] - preds),
      eval_metrics={})  # Unused.
  params = jnp.array(2.)
  batch = {
      'x': jnp.array([1., -1., 1.]),
      'y': jnp.array([0.1, -0.1, -0.1])
  }
  rng = jnp.array(0.5)
  loss = models.model_per_example_loss(model)(params, batch, rng)
  npt.assert_allclose(loss, [2.4, 1.4, 2.6])
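# Worked arithmetic behind the expected losses, from the mock above:
# preds = x * params + rng = x * 2 + 0.5, so
#   x =  1., y =  0.1 -> |0.1 - 2.5|  = 2.4
#   x = -1., y = -0.1 -> |-0.1 + 1.5| = 1.4
#   x =  1., y = -0.1 -> |-0.1 - 2.5| = 2.6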
def test_model_grad(self):
  # Mock out Model.
  model = models.Model(
      init=lambda rng: None,  # Unused.
      apply_for_train=lambda params, batch, rng: batch['x'] * params + rng,
      apply_for_eval=lambda params, batch: None,  # Unused.
      train_loss=lambda batch, preds: jnp.square(batch['y'] - preds) / 2,
      eval_metrics={})  # Unused.
  params = jnp.array(2.)
  batch = {
      'x': jnp.array([1., -1., 1.]),
      'y': jnp.array([0.1, -0.1, -0.1])
  }
  rng = jnp.array(0.5)
  with self.subTest('no regularizer'):
    grads = models.model_grad(model)(params, batch, rng)
    npt.assert_allclose(grads, (2.4 + 1.4 + 2.6) / 3)
  with self.subTest('has regularizer'):
    grads = models.model_grad(model, jnp.abs)(params, batch, rng)
    npt.assert_allclose(grads, (2.4 + 1.4 + 2.6) / 3 + 1)
  with self.subTest('has mask'):
    grads = models.model_grad(model)(params, {
        **batch, '__mask__': jnp.array([True, False, True])
    }, rng)
    npt.assert_allclose(grads, (2.4 + 2.6) / 2)
  with self.subTest('has regularizer, has mask'):
    grads = models.model_grad(model, jnp.abs)(params, {
        **batch, '__mask__': jnp.array([True, False, True])
    }, rng)
    npt.assert_allclose(grads, (2.4 + 2.6) / 2 + 1)
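# Gradient check for the subtests above:
# d/dparams [(y - (x * params + rng))^2 / 2] = (x * params + rng - y) * x,
# which evaluates to 2.4, 1.4 and 2.6 per example (the same numbers as the
# per-example losses in the previous test). model_grad averages these over
# the batch; the jnp.abs regularizer adds d|params|/dparams = 1 at
# params = 2.; and '__mask__' drops the middle example, leaving the mean over
# the two unmasked examples, (2.4 + 2.6) / 2.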
def test_hyp_cluster_evaluator(self):
  functions_called = set()

  def apply_for_eval(params, batch):
    functions_called.add('apply_for_eval')
    score = params * batch['x']
    return jnp.stack([-score, score], axis=-1)

  def apply_for_train(params, batch, rng):
    functions_called.add('apply_for_train')
    self.assertIsNotNone(rng)
    return params * batch['x']

  def train_loss(batch, out):
    functions_called.add('train_loss')
    return jnp.abs(batch['y'] * 2 - 1 - out)

  def regularizer(params):
    # Just to check regularizer is called.
    del params
    functions_called.add('regularizer')
    return 0

  evaluator = hyp_cluster.HypClusterEvaluator(
      models.Model(
          init=None,
          apply_for_eval=apply_for_eval,
          apply_for_train=apply_for_train,
          train_loss=train_loss,
          eval_metrics={'accuracy': metrics.Accuracy()}),
      regularizer)

  cluster_params = [jnp.array(1.), jnp.array(-1.)]
  train_clients = [
      # Evaluated using cluster 0.
      (b'0',
       client_datasets.ClientDataset({
           'x': np.array([3., 2., 1.]),
           'y': np.array([1, 1, 0])
       }), jax.random.PRNGKey(0)),
      # Evaluated using cluster 1.
      (b'1',
       client_datasets.ClientDataset({
           'x': np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
           'y': np.array([0, 1, 0, 1, 0])
       }), jax.random.PRNGKey(1)),
  ]
  # Test clients are generated from train_clients by swapping client ids and
  # then flipping labels.
  test_clients = [
      # Evaluated using cluster 0.
      (b'0',
       client_datasets.ClientDataset({
           'x': np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
           'y': np.array([1, 0, 1, 0, 1])
       })),
      # Evaluated using cluster 1.
      (b'1',
       client_datasets.ClientDataset({
           'x': np.array([3., 2., 1.]),
           'y': np.array([0, 0, 1])
       })),
  ]
  for batch_size in [1, 2, 4]:
    with self.subTest(f'batch_size = {batch_size}'):
      batch_hparams = client_datasets.PaddedBatchHParams(
          batch_size=batch_size)
      metric_values = dict(
          evaluator.evaluate_clients(
              cluster_params=cluster_params,
              train_clients=train_clients,
              test_clients=test_clients,
              batch_hparams=batch_hparams))
      self.assertCountEqual(metric_values, [b'0', b'1'])
      self.assertCountEqual(metric_values[b'0'], ['accuracy'])
      npt.assert_allclose(metric_values[b'0']['accuracy'], 4 / 5)
      self.assertCountEqual(metric_values[b'1'], ['accuracy'])
      npt.assert_allclose(metric_values[b'1']['accuracy'], 2 / 3)
  self.assertCountEqual(
      functions_called,
      ['apply_for_train', 'train_loss', 'apply_for_eval', 'regularizer'])
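# Sanity check for the accuracies above: the eval prediction is
# argmax([-score, score]) = (score > 0) with score = params * x. Client b'0'
# is matched to cluster 0 (params = 1.) on its training data, so its test
# examples score sign(x) = [+, -, +, -, -] -> preds [1, 0, 1, 0, 0] vs
# y = [1, 0, 1, 0, 1], i.e. 4 / 5. Client b'1' under cluster 1 (params = -1.)
# predicts all zeros vs y = [0, 0, 1], i.e. 2 / 3.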
def test_kmeans_init(self):
  functions_called = set()

  def init(rng):
    functions_called.add('init')
    return jax.random.uniform(rng)

  def apply_for_train(params, batch, rng):
    functions_called.add('apply_for_train')
    self.assertIsNotNone(rng)
    return params - batch['x']

  def train_loss(batch, out):
    functions_called.add('train_loss')
    return jnp.square(out) + batch['bias']

  def regularizer(params):
    del params
    functions_called.add('regularizer')
    return 0

  initializer = hyp_cluster.ModelKMeansInitializer(
      models.Model(
          init=init,
          apply_for_train=apply_for_train,
          apply_for_eval=None,
          train_loss=train_loss,
          eval_metrics={}),
      optimizers.sgd(0.5),
      regularizer)
  # Each client has 1 example, so it's very easy to reach minimal loss, at
  # which point the loss entirely depends on bias.
  clients = [
      (b'0',
       client_datasets.ClientDataset({
           'x': np.array([1.01]),
           'bias': np.array([-2.])
       }), jax.random.PRNGKey(1)),
      (b'1',
       client_datasets.ClientDataset({
           'x': np.array([3.02]),
           'bias': np.array([-1.])
       }), jax.random.PRNGKey(2)),
      (b'2',
       client_datasets.ClientDataset({
           'x': np.array([3.03]),
           'bias': np.array([1.])
       }), jax.random.PRNGKey(3)),
      (b'3',
       client_datasets.ClientDataset({
           'x': np.array([1.04]),
           'bias': np.array([2.])
       }), jax.random.PRNGKey(3)),
  ]
  train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
      batch_size=1, num_epochs=5)
  eval_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
  # Using a rng that leads to b'0' being the initial center.
  cluster_params = initializer.cluster_params(
      num_clusters=3,
      rng=jax.random.PRNGKey(0),
      clients=clients,
      train_batch_hparams=train_batch_hparams,
      eval_batch_hparams=eval_batch_hparams)
  self.assertIsInstance(cluster_params, list)
  self.assertLen(cluster_params, 3)
  npt.assert_allclose(cluster_params, [1.01, 3.03, 1.04])
  self.assertCountEqual(
      functions_called,
      ['init', 'apply_for_train', 'train_loss', 'regularizer'])
  # Using a rng that leads to b'2' being the initial center.
  cluster_params = initializer.cluster_params(
      num_clusters=3,
      rng=jax.random.PRNGKey(1),
      clients=clients,
      train_batch_hparams=train_batch_hparams,
      eval_batch_hparams=eval_batch_hparams)
  self.assertIsInstance(cluster_params, list)
  self.assertLen(cluster_params, 3)
  npt.assert_allclose(cluster_params, [3.03, 1.04, 1.04])
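# Why the expected centers are client values: with out = params - x and
# train_loss = out^2 + bias, per-client SGD drives each client's trained
# params toward its single x value (1.01, 3.02, 3.03, 1.04), so the
# initializer's (presumably k-means++-style) seeding chooses centers from
# those trained values, starting from whichever client the rng selects as
# the initial center.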