def test_evaluate_model(self):
  # Mock out Model.
  model = models.Model(
      init=lambda rng: None,  # Unused.
      apply_for_train=lambda params, batch, rng: None,  # Unused.
      apply_for_eval=lambda params, batch: batch.get('pred'),
      train_loss=lambda batch, preds: None,  # Unused.
      eval_metrics={
          'accuracy': metrics.Accuracy(),
          'loss': metrics.CrossEntropyLoss(),
      })
  params = jnp.array(3.14)  # Unused.
  batches = [
      {
          'y': np.array([1, 0, 1]),
          'pred': np.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
      },
      {
          'y': np.array([0, 1, 1]),
          'pred': np.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
      },
      {
          'y': np.array([0, 1]),
          'pred': np.array([[1.2, 0.4], [2.3, 0.1]])
      },
  ]
  eval_results = models.evaluate_model(model, params, batches)
  self.assertEqual(eval_results['accuracy'], 0.625)  # 5 / 8.
  self.assertAlmostEqual(eval_results['loss'], 0.8419596, places=6)
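# A minimal sketch (not part of the test) showing where the expected values
# above come from: `evaluate_model` streams every batch through each metric,
# so the reported numbers are example-weighted means over all 8 examples.
import jax
import jax.numpy as jnp

logits = jnp.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2],
                    [1.2, 0.4], [2.3, 0.1], [0.3, 3.2],
                    [1.2, 0.4], [2.3, 0.1]])
labels = jnp.array([1, 0, 1, 0, 1, 1, 0, 1])
# Accuracy: argmax over logits matches labels for 5 of 8 examples.
accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)  # 0.625.
# Cross-entropy: mean negative log-probability of the true class.
log_probs = jax.nn.log_softmax(logits)
loss = -jnp.mean(log_probs[jnp.arange(8), labels])  # ~0.8419596.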
def test_evaluate_per_client_params(self):
  # Mock out Model.
  model = models.Model(
      init=lambda rng: None,  # Unused.
      apply_for_train=lambda params, batch, rng: None,  # Unused.
      apply_for_eval=lambda params, batch: batch.get('pred') + params,
      train_loss=lambda batch, preds: None,  # Unused.
      eval_metrics={
          'accuracy': metrics.Accuracy(),
          'loss': metrics.CrossEntropyLoss(),
      })
  clients = [
      (b'0000', [{
          'y': np.array([1, 0, 1]),
          'pred': np.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
      }, {
          'y': np.array([0, 1, 1]),
          'pred': np.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
      }], jnp.array([1, -1])),
      (b'1001', [{
          'y': np.array([0, 1]),
          'pred': np.array([[1.2, 0.4], [2.3, 0.1]])
      }], jnp.array([0, 0])),
  ]
  eval_results = dict(
      models.ModelEvaluator(model).evaluate_per_client_params(clients))
  self.assertCountEqual(eval_results, [b'0000', b'1001'])
  self.assertCountEqual(eval_results[b'0000'], ['accuracy', 'loss'])
  npt.assert_allclose(eval_results[b'0000']['accuracy'], 4 / 6)
  npt.assert_allclose(eval_results[b'0000']['loss'], 0.67658216)
  self.assertCountEqual(eval_results[b'1001'], ['accuracy', 'loss'])
  npt.assert_allclose(eval_results[b'1001']['accuracy'], 1 / 2)
  npt.assert_allclose(eval_results[b'1001']['loss'], 1.338092)
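# A minimal sketch (not part of the test) of how the mocked `apply_for_eval`
# exercises per-client parameters: each client's params vector is simply added
# to every row of stored logits.  For client b'0000', adding [1, -1] recovers
# the same logits used in test_evaluate_model, which is why its accuracy is
# again 2/3 per batch (4/6 overall).
import jax.numpy as jnp

client_params = jnp.array([1., -1.])
pred = jnp.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
shifted = pred + client_params  # [[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]].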
# Parameterized test: (target, prediction, expected_result) tuples are
# supplied by the test framework.
def test_accuracy(self, target, prediction, expected_result):
  example = {'y': jnp.array(target)}
  prediction = jnp.array(prediction)
  metric = metrics.Accuracy()
  with self.subTest('zero'):
    zero = metric.zero()
    self.assertEqual(zero.accum, 0)
    self.assertEqual(zero.weight, 0)
  with self.subTest('evaluate_example'):
    accuracy = metric.evaluate_example(example, prediction)
    self.assertEqual(accuracy.result(), expected_result)
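# A minimal sketch (assuming fedjax's MeanStat semantics) of what the
# assertions above exercise: Accuracy accumulates a correct/incorrect
# indicator per example, and result() divides accum by weight.
import jax.numpy as jnp

def accuracy_of_one_example(target, logits):
  # Single-example MeanStat: accum is the 0/1 indicator, weight is 1.
  correct = jnp.equal(jnp.argmax(logits, axis=-1), target).astype(jnp.float32)
  accum, weight = correct, 1.0
  return accum / weight

assert accuracy_of_one_example(jnp.array(1), jnp.array([0.2, 0.8])) == 1.0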
def fake_model():

  def apply_for_eval(params, example):
    del params
    return jax.nn.one_hot(example['x'] % 3, 3)

  eval_metrics = {'accuracy': metrics.Accuracy()}
  return models.Model(
      init=None,
      apply_for_train=None,
      apply_for_eval=apply_for_eval,
      train_loss=None,
      eval_metrics=eval_metrics)
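# A minimal usage sketch (not in the original): the fake model ignores params
# and deterministically "predicts" class x % 3, so accuracy is just the
# fraction of examples whose label equals x % 3.
import jax.numpy as jnp

model = fake_model()
batch = {'x': jnp.array([0, 1, 5]), 'y': jnp.array([0, 1, 1])}
preds = model.apply_for_eval(None, batch)
# argmax(preds) == [0, 1, 2]; only the first two match 'y' (accuracy 2/3).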
def test_per_domain(self):
  metric = metrics.PerDomainMetric(metrics.Accuracy(), num_domains=4)
  stat = metric.zero()
  self.assertIsInstance(stat, metrics.MeanStat)
  npt.assert_array_equal(stat.accum, [0., 0., 0., 0.])
  npt.assert_array_equal(stat.weight, [0., 0., 0., 0.])
  stat = metric.evaluate_example({
      'y': jnp.array(1),
      'domain_id': 0
  }, jnp.array([0., 1.]))
  self.assertIsInstance(stat, metrics.MeanStat)
  npt.assert_array_equal(stat.accum, [1., 0., 0., 0.])
  npt.assert_array_equal(stat.weight, [1., 0., 0., 0.])
  stat = metric.evaluate_example({
      'y': jnp.array(0),
      'domain_id': 2
  }, jnp.array([0., 1.]))
  self.assertIsInstance(stat, metrics.MeanStat)
  npt.assert_array_equal(stat.accum, [0., 0., 0., 0.])
  npt.assert_array_equal(stat.weight, [0., 0., 1., 0.])
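# A minimal sketch (assumption: PerDomainMetric scatters the wrapped metric's
# single-example stat into the vector slot selected by 'domain_id').  The
# second example above lands in domain 2 with accum 0 because
# argmax([0., 1.]) == 1 while y == 0: the example is counted (weight 1) but
# not correct.
import jax
import jax.numpy as jnp

def per_domain_example_stat(domain_id, correct, num_domains=4):
  one_hot = jax.nn.one_hot(domain_id, num_domains)
  return one_hot * correct, one_hot  # (accum, weight) vectors.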
def test_hyp_cluster_evaluator(self):
  functions_called = set()

  def apply_for_eval(params, batch):
    functions_called.add('apply_for_eval')
    score = params * batch['x']
    return jnp.stack([-score, score], axis=-1)

  def apply_for_train(params, batch, rng):
    functions_called.add('apply_for_train')
    self.assertIsNotNone(rng)
    return params * batch['x']

  def train_loss(batch, out):
    functions_called.add('train_loss')
    return jnp.abs(batch['y'] * 2 - 1 - out)

  def regularizer(params):
    # Just to check regularizer is called.
    del params
    functions_called.add('regularizer')
    return 0

  evaluator = hyp_cluster.HypClusterEvaluator(
      models.Model(
          init=None,
          apply_for_eval=apply_for_eval,
          apply_for_train=apply_for_train,
          train_loss=train_loss,
          eval_metrics={'accuracy': metrics.Accuracy()}), regularizer)
  cluster_params = [jnp.array(1.), jnp.array(-1.)]
  train_clients = [
      # Evaluated using cluster 0.
      (b'0',
       client_datasets.ClientDataset({
           'x': np.array([3., 2., 1.]),
           'y': np.array([1, 1, 0])
       }), jax.random.PRNGKey(0)),
      # Evaluated using cluster 1.
      (b'1',
       client_datasets.ClientDataset({
           'x': np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
           'y': np.array([0, 1, 0, 1, 0])
       }), jax.random.PRNGKey(1)),
  ]
  # Test clients are generated from train_clients by swapping client ids and
  # then flipping labels.
  test_clients = [
      # Evaluated using cluster 0.
      (b'0',
       client_datasets.ClientDataset({
           'x': np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
           'y': np.array([1, 0, 1, 0, 1])
       })),
      # Evaluated using cluster 1.
      (b'1',
       client_datasets.ClientDataset({
           'x': np.array([3., 2., 1.]),
           'y': np.array([0, 0, 1])
       })),
  ]
  for batch_size in [1, 2, 4]:
    with self.subTest(f'batch_size = {batch_size}'):
      batch_hparams = client_datasets.PaddedBatchHParams(
          batch_size=batch_size)
      metric_values = dict(
          evaluator.evaluate_clients(
              cluster_params=cluster_params,
              train_clients=train_clients,
              test_clients=test_clients,
              batch_hparams=batch_hparams))
      self.assertCountEqual(metric_values, [b'0', b'1'])
      self.assertCountEqual(metric_values[b'0'], ['accuracy'])
      npt.assert_allclose(metric_values[b'0']['accuracy'], 4 / 5)
      self.assertCountEqual(metric_values[b'1'], ['accuracy'])
      npt.assert_allclose(metric_values[b'1']['accuracy'], 2 / 3)
  self.assertCountEqual(
      functions_called,
      ['apply_for_train', 'train_loss', 'apply_for_eval', 'regularizer'])
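# A minimal sketch (assuming HypCluster's selection step picks the cluster
# with the lowest regularized train loss on the client's train data) of why
# train client b'0' selects cluster 0 above: its mean loss is 5/3 under
# cluster 0 versus 7/3 under cluster 1.
import jax.numpy as jnp

x = jnp.array([3., 2., 1.])
y = jnp.array([1, 1, 0])
losses = []
for params in [jnp.array(1.), jnp.array(-1.)]:
  out = params * x  # apply_for_train.
  losses.append(jnp.mean(jnp.abs(y * 2 - 1 - out)) + 0.)  # + regularizer.
best_cluster = int(jnp.argmin(jnp.array(losses)))  # 0, since 5/3 < 7/3.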
    x = hk.Linear(self._num_classes)(x)
    return x


# Defines the expected structure of input batches to the model. This is used
# to determine the model parameter shapes.
_HAIKU_SAMPLE_BATCH = {
    'x': np.zeros((1, 28, 28, 1), dtype=np.float32),
    'y': np.zeros(1, dtype=np.float32)
}
_STAX_SAMPLE_SHAPE = (-1, 28, 28, 1)

_TRAIN_LOSS = lambda b, p: metrics.unreduced_cross_entropy_loss(b['y'], p)
_EVAL_METRICS = {
    'loss': metrics.CrossEntropyLoss(),
    'accuracy': metrics.Accuracy()
}


def create_conv_model(only_digits: bool = False) -> models.Model:
  """Creates EMNIST CNN model with dropout, implemented in haiku.

  Matches the model used in:

  Adaptive Federated Optimization
    Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush,
    Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan.
    https://arxiv.org/abs/2003.00295

  Args:
    only_digits: Whether to use only digit classes [0-9] or include lower and
from absl.testing import absltest

from fedjax.core import metrics
from fedjax.core import models

import haiku as hk
import jax
try:
  from jax.example_libraries import stax
except ModuleNotFoundError:
  from jax.experimental import stax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt

train_loss = lambda b, p: metrics.unreduced_cross_entropy_loss(b['y'], p)
eval_metrics = {'accuracy': metrics.Accuracy()}


class ModelTest(absltest.TestCase):

  def check_model(self, model):
    with self.subTest('init'):
      params = model.init(jax.random.PRNGKey(0))
      num_params = sum(l.size for l in jax.tree_util.tree_leaves(params))
      self.assertEqual(num_params, 30)

    with self.subTest('apply_for_train'):
      batch = {
          'x': np.array([[1, 2], [3, 4], [5, 6]]),
          'y': np.array([7, 8, 9])
      }
      preds = model.apply_for_train(params, batch, jax.random.PRNGKey(0))