Code example #1
 def test_evaluate_per_client_params(self):
     # Mock out Model.
     model = models.Model(
         init=lambda rng: None,  # Unused.
         apply_for_train=lambda params, batch, rng: None,  # Unused.
         apply_for_eval=lambda params, batch: batch.get('pred') + params,
         train_loss=lambda batch, preds: None,  # Unused.
         eval_metrics={
             'accuracy': metrics.Accuracy(),
             'loss': metrics.CrossEntropyLoss(),
         })
     clients = [
         (b'0000', [{
             'y': np.array([1, 0, 1]),
             'pred': np.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
         }, {
             'y': np.array([0, 1, 1]),
             'pred': np.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
         }], jnp.array([1, -1])),
         (b'1001', [{
             'y': np.array([0, 1]),
             'pred': np.array([[1.2, 0.4], [2.3, 0.1]])
         }], jnp.array([0, 0])),
     ]
     eval_results = dict(
         models.ModelEvaluator(model).evaluate_per_client_params(clients))
     self.assertCountEqual(eval_results, [b'0000', b'1001'])
     self.assertCountEqual(eval_results[b'0000'], ['accuracy', 'loss'])
     npt.assert_allclose(eval_results[b'0000']['accuracy'], 4 / 6)
     npt.assert_allclose(eval_results[b'0000']['loss'], 0.67658216)
     self.assertCountEqual(eval_results[b'1001'], ['accuracy', 'loss'])
     npt.assert_allclose(eval_results[b'1001']['accuracy'], 1 / 2)
     npt.assert_allclose(eval_results[b'1001']['loss'], 1.338092)
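In this test, the mocked apply_for_eval shifts each stored 'pred' row by the per-client params before the metrics see it, which is why client b'0000' (params [1, -1]) scores 4/6. A minimal sketch of that arithmetic in plain numpy, assuming accuracy is the argmax match rate:

import numpy as np

# Client b'0000' logits after the per-client shift: pred + [1, -1].
pred = np.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
logits = pred + np.array([1, -1])  # [[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]]

# Argmax predictions are [0, 0, 1] for both batches.
for y in (np.array([1, 0, 1]), np.array([0, 1, 1])):
    print((logits.argmax(axis=1) == y).sum())  # 2 correct of 3 each time

# 2 + 2 correct over 3 + 3 examples gives the asserted 4 / 6.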
Code example #2
 def test_evaluate_model(self):
     # Mock out Model.
     model = models.Model(
         init=lambda rng: None,  # Unused.
         apply_for_train=lambda params, batch, rng: None,  # Unused.
         apply_for_eval=lambda params, batch: batch.get('pred'),
         train_loss=lambda batch, preds: None,  # Unused.
         eval_metrics={
             'accuracy': metrics.Accuracy(),
             'loss': metrics.CrossEntropyLoss(),
         })
     params = jnp.array(3.14)  # Unused.
     batches = [
         {
             'y': np.array([1, 0, 1]),
             'pred': np.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
         },
         {
             'y': np.array([0, 1, 1]),
             'pred': np.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
         },
         {
             'y': np.array([0, 1]),
             'pred': np.array([[1.2, 0.4], [2.3, 0.1]])
         },
     ]
     eval_results = models.evaluate_model(model, params, batches)
     self.assertEqual(eval_results['accuracy'], 0.625)  # 5 / 8.
     self.assertAlmostEqual(eval_results['loss'], 0.8419596, places=6)
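The asserted numbers can be reproduced by hand. A short sketch, assuming accuracy is the argmax match rate and loss is the mean sparse softmax cross entropy over all 8 examples:

import jax
import jax.numpy as jnp
import numpy as np

# All 8 labels and logit rows from the three batches above, flattened.
y = np.array([1, 0, 1, 0, 1, 1, 0, 1])
logits = np.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]] * 2 +
                  [[1.2, 0.4], [2.3, 0.1]])

print((logits.argmax(axis=1) == y).mean())  # 0.625, i.e. 5 / 8.

log_probs = jax.nn.log_softmax(jnp.asarray(logits))
print(-log_probs[jnp.arange(8), y].mean())  # ~0.8419596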
Code example #3
File: metrics_test.py Project: google/fedjax
 def test_cross_entropy_loss(self):
     example = {'y': jnp.array(1)}
     prediction = jnp.array([1.2, 0.4])
     metric = metrics.CrossEntropyLoss()
     with self.subTest('zero'):
         zero = metric.zero()
         self.assertEqual(zero.accum, 0)
         self.assertEqual(zero.weight, 0)
     with self.subTest('evaluate_example'):
         loss = metric.evaluate_example(example, prediction)
         self.assertAlmostEqual(loss.accum, 1.1711007)
         self.assertAlmostEqual(loss.weight, 1)
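The 1.1711007 checked here is just the negative log softmax probability of class 1 under logits [1.2, 0.4]. A one-liner sketch with jax.nn.log_softmax:

import jax
import jax.numpy as jnp

logits = jnp.array([1.2, 0.4])
# loss = -(0.4 - log(exp(1.2) + exp(0.4))) ~= 1.1711007
print(-jax.nn.log_softmax(logits)[1])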
Code example #4
File: emnist.py Project: google/fedjax
        x = Dropout(rate=0.5)(x, is_train)
        x = hk.Linear(self._num_classes)(x)
        return x


# Defines the expected structure of input batches to the model. This is used to
# determine the model parameter shapes.
_HAIKU_SAMPLE_BATCH = {
    'x': np.zeros((1, 28, 28, 1), dtype=np.float32),
    'y': np.zeros(1, dtype=np.float32)
}
_STAX_SAMPLE_SHAPE = (-1, 28, 28, 1)

_TRAIN_LOSS = lambda b, p: metrics.unreduced_cross_entropy_loss(b['y'], p)
_EVAL_METRICS = {
    'loss': metrics.CrossEntropyLoss(),
    'accuracy': metrics.Accuracy()
}


def create_conv_model(only_digits: bool = False) -> models.Model:
    """Creates EMNIST CNN model with dropout with haiku.

  Matches the model used in:

  Adaptive Federated Optimization
    Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush,
    Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan.
    https://arxiv.org/abs/2003.00295

  Args:
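The excerpt cuts off inside the docstring, but wiring the pieces together is straightforward. A minimal usage sketch, assuming create_conv_model returns the same models.Model container exercised in the test examples above:

import jax
import numpy as np

model = create_conv_model(only_digits=True)

# init takes an RNG key, mirroring the init=lambda rng: ... signature above.
params = model.init(jax.random.PRNGKey(0))

# A dummy batch matching _HAIKU_SAMPLE_BATCH's structure.
batch = {
    'x': np.zeros((1, 28, 28, 1), dtype=np.float32),
    'y': np.zeros(1, dtype=np.float32),
}
logits = model.apply_for_eval(params, batch)  # Shape (1, num_classes).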