Example 1
 def test_evaluate_model(self):
     # Mock out Model.
     model = models.Model(
         init=lambda rng: None,  # Unused.
         apply_for_train=lambda params, batch, rng: None,  # Unused.
         apply_for_eval=lambda params, batch: batch.get('pred'),
         train_loss=lambda batch, preds: None,  # Unused.
         eval_metrics={
             'accuracy': metrics.Accuracy(),
             'loss': metrics.CrossEntropyLoss(),
         })
     params = jnp.array(3.14)  # Unused.
     batches = [
         {
             'y': np.array([1, 0, 1]),
             'pred': np.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
         },
         {
             'y': np.array([0, 1, 1]),
             'pred': np.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
         },
         {
             'y': np.array([0, 1]),
             'pred': np.array([[1.2, 0.4], [2.3, 0.1]])
         },
     ]
     eval_results = models.evaluate_model(model, params, batches)
     self.assertEqual(eval_results['accuracy'], 0.625)  # 5 / 8.
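     # Loss is the mean cross-entropy over all 8 examples.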
     self.assertAlmostEqual(eval_results['loss'], 0.8419596, places=6)
Example 2
 def test_evaluate_per_client_params(self):
     # Mock out Model.
     model = models.Model(
         init=lambda rng: None,  # Unused.
         apply_for_train=lambda params, batch, rng: None,  # Unused.
         apply_for_eval=lambda params, batch: batch.get('pred') + params,
         train_loss=lambda batch, preds: None,  # Unused.
         eval_metrics={
             'accuracy': metrics.Accuracy(),
             'loss': metrics.CrossEntropyLoss(),
         })
     clients = [
         (b'0000', [{
             'y': np.array([1, 0, 1]),
             'pred': np.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
         }, {
             'y': np.array([0, 1, 1]),
             'pred': np.array([[0.2, 1.4], [1.3, 1.1], [-0.7, 4.2]])
         }], jnp.array([1, -1])),
         (b'1001', [{
             'y': np.array([0, 1]),
             'pred': np.array([[1.2, 0.4], [2.3, 0.1]])
         }], jnp.array([0, 0])),
     ]
     eval_results = dict(
         models.ModelEvaluator(model).evaluate_per_client_params(clients))
     self.assertCountEqual(eval_results, [b'0000', b'1001'])
     self.assertCountEqual(eval_results[b'0000'], ['accuracy', 'loss'])
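     # For b'0000', adding params [1, -1] to each logit row gives argmax
     # predictions [0, 0, 1] per batch, matching 4 of the 6 labels.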
     npt.assert_allclose(eval_results[b'0000']['accuracy'], 4 / 6)
     npt.assert_allclose(eval_results[b'0000']['loss'], 0.67658216)
     self.assertCountEqual(eval_results[b'1001'], ['accuracy', 'loss'])
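     # For b'1001', params [0, 0] leave the logits unchanged, so the argmax
     # predictions [0, 0] match 1 of the 2 labels.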
     npt.assert_allclose(eval_results[b'1001']['accuracy'], 1 / 2)
     npt.assert_allclose(eval_results[b'1001']['loss'], 1.338092)
Example 3
def fake_model():
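  """Returns a mock Model whose eval prediction is one_hot(x % 3, 3)."""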

  def apply_for_eval(params, example):
    del params
    return jax.nn.one_hot(example['x'] % 3, 3)

  eval_metrics = {'accuracy': metrics.Accuracy()}
  return models.Model(
      init=None,
      apply_for_train=None,
      apply_for_eval=apply_for_eval,
      train_loss=None,
      eval_metrics=eval_metrics)
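
A minimal usage sketch for fake_model (not part of the original source), assuming the same imports as the surrounding examples (numpy as np, fedjax's models module) and that models.evaluate_model behaves as in Example 1, taking batches with 'x' and 'y' keys:

def sketch_evaluate_fake_model():
  # Hypothetical helper, for illustration only.
  model = fake_model()
  batches = [{'x': np.array([0, 1, 2, 3]), 'y': np.array([0, 1, 2, 1])}]
  # apply_for_eval one-hot encodes x % 3, so the predicted classes are
  # [0, 1, 2, 0]; against y = [0, 1, 2, 1] that is 3 correct out of 4.
  results = models.evaluate_model(model, None, batches)
  assert float(results['accuracy']) == 0.75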
Example 4
    def test_model_per_example_loss(self):
        # Mock out Model.
        model = models.Model(
            init=lambda rng: None,  # Unused.
            apply_for_train=(
                lambda params, batch, rng: batch['x'] * params + rng),
            apply_for_eval=lambda params, batch: None,  # Unused.
            train_loss=lambda batch, preds: jnp.abs(batch['y'] - preds),
            eval_metrics={}  # Unused.
        )

        params = jnp.array(2.)
        batch = {
            'x': jnp.array([1., -1., 1.]),
            'y': jnp.array([0.1, -0.1, -0.1])
        }
        rng = jnp.array(0.5)

        loss = models.model_per_example_loss(model)(params, batch, rng)
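        # preds = x * params + rng = [2.5, -1.5, 2.5]; the per-example loss is
        # |y - preds| = [2.4, 1.4, 2.6].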
        npt.assert_allclose(loss, [2.4, 1.4, 2.6])
Example 5
    def test_model_grad(self):
        # Mock out Model.
        model = models.Model(
            init=lambda rng: None,  # Unused.
            apply_for_train=(
                lambda params, batch, rng: batch['x'] * params + rng),
            apply_for_eval=lambda params, batch: None,  # Unused.
            train_loss=lambda batch, preds: jnp.square(batch['y'] - preds) / 2,
            eval_metrics={}  # Unused.
        )

        params = jnp.array(2.)
        batch = {
            'x': jnp.array([1., -1., 1.]),
            'y': jnp.array([0.1, -0.1, -0.1])
        }
        rng = jnp.array(0.5)
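        # The per-example gradient of the squared-error loss w.r.t. params is
        # (x * params + rng - y) * x = [2.4, 1.4, 2.6]; jnp.abs as regularizer
        # adds 1, and the mask drops the middle example from the average.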

        with self.subTest('no regularizer'):
            grads = models.model_grad(model)(params, batch, rng)
            npt.assert_allclose(grads, (2.4 + 1.4 + 2.6) / 3)

        with self.subTest('has regularizer'):
            grads = models.model_grad(model, jnp.abs)(params, batch, rng)
            npt.assert_allclose(grads, (2.4 + 1.4 + 2.6) / 3 + 1)

        with self.subTest('has mask'):
            grads = models.model_grad(model)(
                params, {
                    **batch, '__mask__': jnp.array([True, False, True])
                }, rng)
            npt.assert_allclose(grads, (2.4 + 2.6) / 2)

        with self.subTest('has regularizer, has mask'):
            grads = models.model_grad(model, jnp.abs)(
                params, {
                    **batch, '__mask__': jnp.array([True, False, True])
                }, rng)
            npt.assert_allclose(grads, (2.4 + 2.6) / 2 + 1)
Example 6
    def test_hyp_cluster_evaluator(self):
        functions_called = set()

        def apply_for_eval(params, batch):
            functions_called.add('apply_for_eval')
            score = params * batch['x']
            return jnp.stack([-score, score], axis=-1)

        def apply_for_train(params, batch, rng):
            functions_called.add('apply_for_train')
            self.assertIsNotNone(rng)
            return params * batch['x']

        def train_loss(batch, out):
            functions_called.add('train_loss')
            return jnp.abs(batch['y'] * 2 - 1 - out)

        def regularizer(params):
            # Just to check regularizer is called.
            del params
            functions_called.add('regularizer')
            return 0

        evaluator = hyp_cluster.HypClusterEvaluator(
            models.Model(init=None,
                         apply_for_eval=apply_for_eval,
                         apply_for_train=apply_for_train,
                         train_loss=train_loss,
                         eval_metrics={'accuracy': metrics.Accuracy()}),
            regularizer)

        cluster_params = [jnp.array(1.), jnp.array(-1.)]
        train_clients = [
            # Evaluated using cluster 0.
            (b'0',
             client_datasets.ClientDataset({
                 'x': np.array([3., 2., 1.]),
                 'y': np.array([1, 1, 0])
             }), jax.random.PRNGKey(0)),
            # Evaluated using cluster 1.
            (b'1',
             client_datasets.ClientDataset({
                 'x': np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
                 'y': np.array([0, 1, 0, 1, 0])
             }), jax.random.PRNGKey(1)),
        ]
        # Test clients are generated from train_clients by swapping client ids and
        # then flipping labels.
        test_clients = [
            # Evaluated using cluster 0.
            (b'0',
             client_datasets.ClientDataset({
                 'x': np.array([0.9, -0.9, 0.8, -0.8, -0.3]),
                 'y': np.array([1, 0, 1, 0, 1])
             })),
            # Evaluated using cluster 1.
            (b'1',
             client_datasets.ClientDataset({
                 'x': np.array([3., 2., 1.]),
                 'y': np.array([0, 0, 1])
             })),
        ]
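        # With apply_for_eval above, the predicted label is 1 exactly when
        # cluster_param * x > 0, giving 4/5 accuracy on test client b'0'
        # (cluster 0) and 2/3 on test client b'1' (cluster 1).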
        for batch_size in [1, 2, 4]:
            with self.subTest(f'batch_size = {batch_size}'):
                batch_hparams = client_datasets.PaddedBatchHParams(
                    batch_size=batch_size)
                metric_values = dict(
                    evaluator.evaluate_clients(cluster_params=cluster_params,
                                               train_clients=train_clients,
                                               test_clients=test_clients,
                                               batch_hparams=batch_hparams))
                self.assertCountEqual(metric_values, [b'0', b'1'])
                self.assertCountEqual(metric_values[b'0'], ['accuracy'])
                npt.assert_allclose(metric_values[b'0']['accuracy'], 4 / 5)
                self.assertCountEqual(metric_values[b'1'], ['accuracy'])
                npt.assert_allclose(metric_values[b'1']['accuracy'], 2 / 3)
        self.assertCountEqual(
            functions_called,
            ['apply_for_train', 'train_loss', 'apply_for_eval', 'regularizer'])
Example 7
    def test_kmeans_init(self):
        functions_called = set()

        def init(rng):
            functions_called.add('init')
            return jax.random.uniform(rng)

        def apply_for_train(params, batch, rng):
            functions_called.add('apply_for_train')
            self.assertIsNotNone(rng)
            return params - batch['x']

        def train_loss(batch, out):
            functions_called.add('train_loss')
            return jnp.square(out) + batch['bias']

        def regularizer(params):
            del params
            functions_called.add('regularizer')
            return 0

        initializer = hyp_cluster.ModelKMeansInitializer(
            models.Model(init=init,
                         apply_for_train=apply_for_train,
                         apply_for_eval=None,
                         train_loss=train_loss,
                         eval_metrics={}), optimizers.sgd(0.5), regularizer)
        # Each client has 1 example, so it's very easy to reach minimal loss, at
        # which point the loss entirely depends on bias.
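        # With SGD at learning rate 0.5, a single step moves a client's params
        # to its x value, so the candidate centers are 1.01, 3.02, 3.03, 1.04.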
        clients = [
            (b'0',
             client_datasets.ClientDataset({
                 'x': np.array([1.01]),
                 'bias': np.array([-2.])
             }), jax.random.PRNGKey(1)),
            (b'1',
             client_datasets.ClientDataset({
                 'x': np.array([3.02]),
                 'bias': np.array([-1.])
             }), jax.random.PRNGKey(2)),
            (b'2',
             client_datasets.ClientDataset({
                 'x': np.array([3.03]),
                 'bias': np.array([1.])
             }), jax.random.PRNGKey(3)),
            (b'3',
             client_datasets.ClientDataset({
                 'x': np.array([1.04]),
                 'bias': np.array([2.])
             }), jax.random.PRNGKey(3)),
        ]
        train_batch_hparams = client_datasets.ShuffleRepeatBatchHParams(
            batch_size=1, num_epochs=5)
        eval_batch_hparams = client_datasets.PaddedBatchHParams(batch_size=2)
        # Using a rng that leads to b'0' being the initial center.
        cluster_params = initializer.cluster_params(
            num_clusters=3,
            rng=jax.random.PRNGKey(0),
            clients=clients,
            train_batch_hparams=train_batch_hparams,
            eval_batch_hparams=eval_batch_hparams)
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 3)
        npt.assert_allclose(cluster_params, [1.01, 3.03, 1.04])
        self.assertCountEqual(
            functions_called,
            ['init', 'apply_for_train', 'train_loss', 'regularizer'])
        # Using a rng that leads to b'2' being the initial center.
        cluster_params = initializer.cluster_params(
            num_clusters=3,
            rng=jax.random.PRNGKey(1),
            clients=clients,
            train_batch_hparams=train_batch_hparams,
            eval_batch_hparams=eval_batch_hparams)
        self.assertIsInstance(cluster_params, list)
        self.assertLen(cluster_params, 3)
        npt.assert_allclose(cluster_params, [3.03, 1.04, 1.04])