Code example #1
    def test_evaluate_model(self):
        # Mock out Model.
        model = models.Model(
            init=lambda rng: None,  # Unused.
            apply_for_train=lambda params, batch, rng: None,  # Unused.
            apply_for_eval=lambda params, batch: batch.get('pred'),
            train_loss=lambda batch, preds: None,  # Unused.
            eval_metrics={
                'accuracy': metrics.Accuracy(),
                'loss': metrics.CrossEntropyLoss(),
            })
        params = jnp.array(3.14)  # Unused.
        batches = [
            {
                'y': np.array([1, 0, 1]),
                'pred': np.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
            },
            {
                'y': np.array([0, 1, 1]),
                'pred': np.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
            },
            {
                'y': np.array([0, 1]),
                'pred': np.array([[1.2, 0.4], [2.3, 0.1]])
            },
        ]
        eval_results = models.evaluate_model(model, params, batches)
        self.assertEqual(eval_results['accuracy'], 0.625)  # 5 / 8.
        self.assertAlmostEqual(eval_results['loss'], 0.8419596, places=6)
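
The expected values above can be reproduced by hand: Accuracy takes the argmax of each prediction row and compares it against the integer label in 'y', averaging over all examples in all batches. A minimal NumPy sketch of that reduction (not the fedjax implementation):

import numpy as np

def naive_accuracy(batches):
    # Count argmax matches over every example in every batch.
    correct = sum((b['pred'].argmax(axis=-1) == b['y']).sum() for b in batches)
    total = sum(len(b['y']) for b in batches)
    return correct / total  # 5 / 8 = 0.625 for the batches above.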
Code example #2
    def test_evaluate_per_client_params(self):
        # Mock out Model.
        model = models.Model(
            init=lambda rng: None,  # Unused.
            apply_for_train=lambda params, batch, rng: None,  # Unused.
            apply_for_eval=lambda params, batch: batch.get('pred') + params,
            train_loss=lambda batch, preds: None,  # Unused.
            eval_metrics={
                'accuracy': metrics.Accuracy(),
                'loss': metrics.CrossEntropyLoss(),
            })
        clients = [
            (b'0000', [{
                'y': np.array([1, 0, 1]),
                'pred': np.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
            }, {
                'y': np.array([0, 1, 1]),
                'pred': np.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
            }], jnp.array([1, -1])),
            (b'1001', [{
                'y': np.array([0, 1]),
                'pred': np.array([[1.2, 0.4], [2.3, 0.1]])
            }], jnp.array([0, 0])),
        ]
        eval_results = dict(
            models.ModelEvaluator(model).evaluate_per_client_params(clients))
        self.assertCountEqual(eval_results, [b'0000', b'1001'])
        self.assertCountEqual(eval_results[b'0000'], ['accuracy', 'loss'])
        npt.assert_allclose(eval_results[b'0000']['accuracy'], 4 / 6)
        npt.assert_allclose(eval_results[b'0000']['loss'], 0.67658216)
        self.assertCountEqual(eval_results[b'1001'], ['accuracy', 'loss'])
        npt.assert_allclose(eval_results[b'1001']['accuracy'], 1 / 2)
        npt.assert_allclose(eval_results[b'1001']['loss'], 1.338092)
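
evaluate_per_client_params consumes (client_id, batches, client_params) triples and yields one metrics mapping per client, evaluated with that client's own parameters. A rough equivalent in terms of evaluate_model from example #1 (a sketch only; ModelEvaluator itself is the optimized path):

def naive_evaluate_per_client_params(model, clients):
    # One independent evaluation per client, using that client's params.
    for client_id, batches, client_params in clients:
        yield client_id, models.evaluate_model(model, client_params, batches)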
Code example #3
File: metrics_test.py Project: google/fedjax
 def test_accuracy(self, target, prediction, expected_result):
     example = {'y': jnp.array(target)}
     prediction = jnp.array(prediction)
     metric = metrics.Accuracy()
     with self.subTest('zero'):
         zero = metric.zero()
         self.assertEqual(zero.accum, 0)
         self.assertEqual(zero.weight, 0)
     with self.subTest('evaluate_example'):
         accuracy = metric.evaluate_example(example, prediction)
         self.assertEqual(accuracy.result(), expected_result)
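
zero() above exposes the two fields of the MeanStat accumulator behind Accuracy: accum (the weighted sum of per-example values) and weight (the total weight). A plain-Python sketch of the assumed semantics, including the weighted mean that result() returns:

class NaiveMeanStat:
    """Stand-in for the MeanStat used by Accuracy (assumed semantics)."""

    def __init__(self, accum=0.0, weight=0.0):
        self.accum, self.weight = accum, weight

    def merge(self, other):
        # Combining two accumulators just adds their fields.
        return NaiveMeanStat(self.accum + other.accum,
                             self.weight + other.weight)

    def result(self):
        # Weighted mean; 0 when nothing has been accumulated.
        return 0.0 if self.weight == 0 else self.accum / self.weight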
Code example #4
def fake_model():

  def apply_for_eval(params, example):
    del params
    return jax.nn.one_hot(example['x'] % 3, 3)

  eval_metrics = {'accuracy': metrics.Accuracy()}
  return models.Model(
      init=None,
      apply_for_train=None,
      apply_for_eval=apply_for_eval,
      train_loss=None,
      eval_metrics=eval_metrics)
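
fake_model ignores params entirely and deterministically predicts class x % 3, which makes expected metric values trivial to compute in tests. A hypothetical usage (the batch below is made up):

import numpy as np

model = fake_model()
batch = {'x': np.array([0, 1, 2]), 'y': np.array([0, 1, 0])}
# Predicted classes are x % 3 = [0, 1, 2], so 2 of the 3 labels match.
results = models.evaluate_model(model, params=None, batches=[batch])
# -> accuracy of 2 / 3.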
Code example #5
File: metrics_test.py Project: google/fedjax
    def test_per_domain(self):
        metric = metrics.PerDomainMetric(metrics.Accuracy(), num_domains=4)

        stat = metric.zero()
        self.assertIsInstance(stat, metrics.MeanStat)
        npt.assert_array_equal(stat.accum, [0., 0., 0., 0.])
        npt.assert_array_equal(stat.weight, [0., 0., 0., 0.])

        stat = metric.evaluate_example({
            'y': jnp.array(1),
            'domain_id': 0
        }, jnp.array([0., 1.]))
        self.assertIsInstance(stat, metrics.MeanStat)
        npt.assert_array_equal(stat.accum, [1., 0., 0., 0.])
        npt.assert_array_equal(stat.weight, [1., 0., 0., 0.])

        stat = metric.evaluate_example({
            'y': jnp.array(0),
            'domain_id': 2
        }, jnp.array([0., 1.]))
        self.assertIsInstance(stat, metrics.MeanStat)
        npt.assert_array_equal(stat.accum, [0., 0., 0., 0.])
        npt.assert_array_equal(stat.weight, [0., 0., 1., 0.])
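
PerDomainMetric keeps one accum/weight slot per domain, indexed by each example's domain_id, and per-domain results follow the same weighted-mean rule. A NumPy sketch combining the two stats above (assumed semantics, with empty domains reported as 0):

import numpy as np

accum = np.array([1., 0., 0., 0.])  # Summed over both evaluate_example calls.
weight = np.array([1., 0., 0., 0.]) + np.array([0., 0., 1., 0.])
# Per-domain mean; domains 1 and 3 saw no examples and stay at 0.
per_domain = np.where(weight == 0, 0., accum / np.maximum(weight, 1.))
print(per_domain)  # [1. 0. 0. 0.]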
Code example #6
File: hyp_cluster_test.py Project: google/fedjax
    def test_hyp_cluster_evaluator(self):
        functions_called = set()

        def apply_for_eval(params, batch):
            functions_called.add('apply_for_eval')
            score = params * batch['x']
            return jnp.stack([-score, score], axis=-1)

        def apply_for_train(params, batch, rng):
            functions_called.add('apply_for_train')
            self.assertIsNotNone(rng)
            return params * batch['x']

        def train_loss(batch, out):
            functions_called.add('train_loss')
            return jnp.abs(batch['y'] * 2 - 1 - out)

        def regularizer(params):
            # Just to check regularizer is called.
            del params
            functions_called.add('regularizer')
            return 0

        evaluator = hyp_cluster.HypClusterEvaluator(
            models.Model(init=None,
                         apply_for_eval=apply_for_eval,
                         apply_for_train=apply_for_train,
                         train_loss=train_loss,
                         eval_metrics={'accuracy': metrics.Accuracy()}),
            regularizer)

        cluster_params = [jnp.array(1.), jnp.array(-1.)]
        train_clients = [
            # Evaluated using cluster 0.
            (b'0',
             client_datasets.ClientDataset({
                 'x': np.array([3., 2., 1.]),
                 'y': np.array([1, 1, 0])
             }), jax.random.PRNGKey(0)),
            # Evaluated using cluster 1.
            (b'1',
             client_datasets.ClientDataset({
                 'x':
                 np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
                 'y':
                 np.array([0, 1, 0, 1, 0])
             }), jax.random.PRNGKey(1)),
        ]
        # Test clients are generated from train_clients by swapping client ids and
        # then flipping labels.
        test_clients = [
            # Evaluated using cluster 0.
            (b'0',
             client_datasets.ClientDataset({
                 'x':
                 np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
                 'y':
                 np.array([1, 0, 1, 0, 1])
             })),
            # Evaluated using cluster 1.
            (b'1',
             client_datasets.ClientDataset({
                 'x': np.array([3., 2., 1.]),
                 'y': np.array([0, 0, 1])
             })),
        ]
        for batch_size in [1, 2, 4]:
            with self.subTest(f'batch_size = {batch_size}'):
                batch_hparams = client_datasets.PaddedBatchHParams(
                    batch_size=batch_size)
                metric_values = dict(
                    evaluator.evaluate_clients(cluster_params=cluster_params,
                                               train_clients=train_clients,
                                               test_clients=test_clients,
                                               batch_hparams=batch_hparams))
                self.assertCountEqual(metric_values, [b'0', b'1'])
                self.assertCountEqual(metric_values[b'0'], ['accuracy'])
                npt.assert_allclose(metric_values[b'0']['accuracy'], 4 / 5)
                self.assertCountEqual(metric_values[b'1'], ['accuracy'])
                npt.assert_allclose(metric_values[b'1']['accuracy'], 2 / 3)
        self.assertCountEqual(
            functions_called,
            ['apply_for_train', 'train_loss', 'apply_for_eval', 'regularizer'])
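
The expected accuracies can be checked by hand. Under cluster 0 (params = 1.0), apply_for_eval scores each example with x and stacks [-score, score], so the predicted class is 1 exactly when x > 0. For the test client b'0':

import numpy as np

x = np.array([0.9, -0.9, 0.8, -0.8, -0.3])
y = np.array([1, 0, 1, 0, 1])
pred = (1.0 * x > 0).astype(int)  # argmax of [-score, score]
print((pred == y).mean())         # 0.8, i.e. the asserted 4 / 5.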
Code example #7
File: emnist.py Project: google/fedjax
        x = hk.Linear(self._num_classes)(x)
        return x


# Defines the expected structure of input batches to the model. This is used to
# determine the model parameter shapes.
_HAIKU_SAMPLE_BATCH = {
    'x': np.zeros((1, 28, 28, 1), dtype=np.float32),
    'y': np.zeros(1, dtype=np.float32)
}
_STAX_SAMPLE_SHAPE = (-1, 28, 28, 1)

_TRAIN_LOSS = lambda b, p: metrics.unreduced_cross_entropy_loss(b['y'], p)
_EVAL_METRICS = {
    'loss': metrics.CrossEntropyLoss(),
    'accuracy': metrics.Accuracy()
}


def create_conv_model(only_digits: bool = False) -> models.Model:
  """Creates an EMNIST CNN model with dropout, implemented in haiku.

  Matches the model used in:

  Adaptive Federated Optimization
    Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush,
    Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan.
    https://arxiv.org/abs/2003.00295

  Args:
    only_digits: Whether to use only digit classes [0-9] or include lower and
      upper case characters, for a total of 62 classes.
  """
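
_HAIKU_SAMPLE_BATCH is never trained on; it only drives shape inference when haiku initializes parameters. A minimal sketch of that pattern, assuming a hypothetical forward function that takes a batch dict like the real model's does:

def forward(batch):  # Hypothetical stand-in for the model's forward pass.
    x = batch['x'].reshape((batch['x'].shape[0], -1))
    return hk.Linear(10)(x)

transformed = hk.transform(forward)
# Tracing the sample batch is enough to fix every parameter shape.
params = transformed.init(jax.random.PRNGKey(0), _HAIKU_SAMPLE_BATCH)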
Code example #8
from absl.testing import absltest

from fedjax.core import metrics
from fedjax.core import models

import haiku as hk
import jax
try:
    from jax.example_libraries import stax
except ModuleNotFoundError:
    from jax.experimental import stax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt

train_loss = lambda b, p: metrics.unreduced_cross_entropy_loss(b['y'], p)
eval_metrics = {'accuracy': metrics.Accuracy()}


class ModelTest(absltest.TestCase):
    def check_model(self, model):
        with self.subTest('init'):
            params = model.init(jax.random.PRNGKey(0))
            num_params = sum(l.size for l in jax.tree_util.tree_leaves(params))
            self.assertEqual(num_params, 30)

        with self.subTest('apply_for_train'):
            batch = {
                'x': np.array([[1, 2], [3, 4], [5, 6]]),
                'y': np.array([7, 8, 9])
            }
            preds = model.apply_for_train(params, batch, jax.random.PRNGKey(0))
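
The snippet ends mid-test, but the train_loss and eval_metrics defined at the top are exactly the pieces a Model needs. A hedged sketch of wiring them into a stax-backed model via models.create_model_from_stax (the keyword names and sample shape here are assumptions; a single Dense(10) layer over 2 features yields the 30 parameters the test asserts):

stax_init, stax_apply = stax.serial(stax.Dense(10))
model = models.create_model_from_stax(
    stax_init=stax_init,       # Assumed keywords; check the fedjax docs.
    stax_apply=stax_apply,
    sample_shape=(-1, 2),      # Matches the (3, 2) 'x' batch above.
    train_loss=train_loss,
    eval_metrics=eval_metrics)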