def test_mnist_training_round_trip(self):
  """Round-trips the MNIST example through CanonicalForm and compares runs.

  Builds an iterative process from the MNIST canonical form, converts it back
  to a canonical form, and rebuilds a second process from that. The two
  processes must stringify their initial states identically, agree on the
  names and values of the state and metrics after one round on a single
  client, and contain the same number of TensorFlow variables in `next`.
  """
  to_process = canonical_form_utils.get_iterative_process_for_canonical_form
  to_canonical_form = (
      canonical_form_utils.get_canonical_form_for_iterative_process)
  process = to_process(test_utils.get_mnist_training_example())
  rebuilt = to_process(to_canonical_form(process))

  initial_state = process.initialize()
  rebuilt_initial_state = rebuilt.initialize()
  self.assertEqual(str(initial_state), str(rebuilt_initial_state))

  # One client holding a single example with 784 float32 features.
  client_data = [
      collections.OrderedDict(
          x=np.array([[0.5] * 784], dtype=np.float32),
          y=np.array([1], dtype=np.int32))
  ]
  result = process.next(initial_state, [client_data])
  rebuilt_result = rebuilt.next(rebuilt_initial_state, [client_data])

  # Element 0 is the server state and element 1 the metrics. Names are
  # compared for both before any values are compared.
  pairs = [(result[0], rebuilt_result[0]), (result[1], rebuilt_result[1])]
  for actual, expected in pairs:
    self.assertAllEqual(
        anonymous_tuple.name_list(actual), anonymous_tuple.name_list(expected))
  for actual, expected in pairs:
    self.assertAllClose(actual, expected)

  def _count_variables(iterative_process):
    # Number of TF variables under the process's `next` computation.
    return tree_analysis.count_tensorflow_variables_under(
        test_utils.computation_to_building_block(iterative_process.next))

  self.assertEqual(_count_variables(process), _count_variables(rebuilt))
    def test_with_temperature_sensor_example(self):
        """Runs two rounds of the temperature-sensor canonical-form example.

        After initialization and after each round, checks the single
        `num_rounds` state field, the `ratio_over_threshold` metric, and the
        per-client `num_readings` statistics.
        """
        cf = test_utils.get_temperature_sensor_example()
        it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

        state = it.initialize()
        self.assertLen(state, 1)
        self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
        self.assertEqual(state[0], 0)

        # Round 1: two clients holding 1 and 3 readings respectively.
        state, metrics, stats = it.next(state, [[28.0], [30.0, 33.0, 29.0]])
        self.assertLen(state, 1)
        self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
        self.assertEqual(state[0], 1)
        self.assertLen(metrics, 1)
        self.assertAllEqual(anonymous_tuple.name_list(metrics),
                            ['ratio_over_threshold'])
        self.assertEqual(metrics[0], 0.5)
        self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                              [1, 3])

        # Round 2: four clients holding a single reading each.
        state, metrics, stats = it.next(state,
                                        [[33.0], [34.0], [35.0], [36.0]])
        self.assertAllEqual(state, (2, ))
        self.assertAllClose(metrics, {'ratio_over_threshold': 0.75})
        # Evaluate the per-client counts exactly as in round 1, so that both
        # rounds compare values of the same kind.
        self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                              [1, 1, 1, 1])
  def test_temperature_example_round_trip(self):
    """Round-trips the temperature-sensor example and runs two rounds.

    The process rebuilt from the round-tripped canonical form must produce the
    expected state, metrics and per-client statistics for two rounds. Its
    `next` computation must also contain as many TensorFlow variables as the
    original process's.
    """
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_temperature_sensor_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    state = new_it.initialize()
    self.assertLen(state, 1)
    self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
    self.assertEqual(state[0], 0)

    # Round 1: two clients holding 1 and 3 readings respectively.
    state, metrics, stats = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
    self.assertLen(state, 1)
    self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
    self.assertEqual(state[0], 1)
    self.assertLen(metrics, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(metrics), ['ratio_over_threshold'])
    self.assertEqual(metrics[0], 0.5)
    self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                          [1, 3])

    # Round 2: four clients holding a single reading each.
    state, metrics, stats = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
    self.assertAllEqual(state, (2,))
    self.assertAllClose(metrics, {'ratio_over_threshold': 0.75})
    # Evaluate the per-client counts exactly as in round 1, so that both
    # rounds compare values of the same kind.
    self.assertCountEqual([self.evaluate(x.num_readings) for x in stats],
                          [1, 1, 1, 1])
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))
Example #4
0
    def test_returns_canonical_form_from_tff_learning_structure(self):
        """Checks that a learning-style process survives a CanonicalForm round trip.

        Converts the example training process to a `CanonicalForm` and
        rebuilds an iterative process from it. Then checks that the type
        signatures agree where they must, and that server state and metrics
        match after one round on a single client.
        """
        it = test_utils.construct_example_training_comp()
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
        new_it = canonical_form_utils.get_iterative_process_for_canonical_form(
            cf)
        self.assertIsInstance(cf, canonical_form.CanonicalForm)
        self.assertEqual(it.initialize.type_signature,
                         new_it.initialize.type_signature)

        # The full `next` signatures may differ: the round trip can append an
        # empty tuple of client side-channel outputs when none existed. Only
        # the parameter and the first two result elements are compared.
        old_next_type = it.next.type_signature
        new_next_type = new_it.next.type_signature
        self.assertEqual(old_next_type.parameter, new_next_type.parameter)
        for index in (0, 1):
            self.assertEqual(old_next_type.result[index],
                             new_next_type.result[index])

        state1 = it.initialize()
        state2 = new_it.initialize()

        client_data = [
            collections.OrderedDict(x=np.array([[1., 1.]], dtype=np.float32),
                                    y=np.array([[0]], dtype=np.int32))
        ]

        def _summarize(round_result):
            # Splits a round's result into the server state plus the names and
            # flattened leaves of both the state (element 0) and the metrics
            # (element 1).
            server_state = round_result[0]
            state_names = anonymous_tuple.name_list(server_state)
            state_leaves = anonymous_tuple.flatten(server_state)
            round_metrics = round_result[1]
            metric_names = [
                name for name, _ in anonymous_tuple.iter_elements(round_metrics)
            ]
            metric_leaves = anonymous_tuple.flatten(round_metrics)
            return (server_state, state_names, state_leaves, metric_names,
                    metric_leaves)

        (state, state_names, state_arrays, metrics_names,
         metrics_arrays) = _summarize(it.next(state1, [client_data]))
        (_, alt_state_names, alt_state_arrays, alt_metrics_names,
         alt_metrics_arrays) = _summarize(new_it.next(state2, [client_data]))

        self.assertEmpty(state.delta_aggregate_state)
        self.assertEmpty(state.model_broadcast_state)
        self.assertAllEqual(state_names, alt_state_names)
        self.assertAllEqual(metrics_names, alt_metrics_names)
        self.assertAllClose(state_arrays, alt_state_arrays)
        self.assertAllClose(metrics_arrays[:2], alt_metrics_arrays[:2])
        # The metric at index 2 is an execution time, so it is only compared
        # approximately.
        self.assertAlmostEqual(metrics_arrays[2],
                               alt_metrics_arrays[2],
                               delta=1e-3)
Example #5
0
    def test_canonical_form_with_learning_structure_does_not_change_execution_of_iterative_process(
            self):
        """Checks that a CanonicalForm round trip preserves one round of training.

        Builds `ip_2` from the canonical form of `ip_1`. Compares the type
        signatures that must agree, then compares server state and output
        after one round on a single client.
        """
        ip_1 = construct_example_training_comp()
        cf = tff.backends.mapreduce.get_canonical_form_for_iterative_process(
            ip_1)
        ip_2 = tff.backends.mapreduce.get_iterative_process_for_canonical_form(
            cf)

        self.assertEqual(ip_1.initialize.type_signature,
                         ip_2.initialize.type_signature)
        # Only parts of the `next` signatures must match: the round trip may
        # append an empty tuple of client side-channel outputs if none existed.
        next_type_1 = ip_1.next.type_signature
        next_type_2 = ip_2.next.type_signature
        self.assertEqual(next_type_1.parameter, next_type_2.parameter)
        self.assertEqual(next_type_1.result[0], next_type_2.result[0])
        self.assertEqual(next_type_1.result[1], next_type_2.result[1])

        client_data = [
            collections.OrderedDict(
                x=np.array([[1., 1.]], dtype=np.float32),
                y=np.array([[0]], dtype=np.int32),
            )
        ]

        def _element_names(value):
            # Names of every element, including `None` for unnamed ones.
            return [name for name, _ in anonymous_tuple.iter_elements(value)]

        server_state_1, server_output_1 = ip_1.next(ip_1.initialize(),
                                                    [client_data])
        state_1_names = anonymous_tuple.name_list(server_state_1)
        state_1_arrays = anonymous_tuple.flatten(server_state_1)
        output_1_names = _element_names(server_output_1)
        output_1_arrays = anonymous_tuple.flatten(server_output_1)

        # The rebuilt process returns a third (side-channel) value; ignore it.
        server_state_2, server_output_2, _ = ip_2.next(ip_2.initialize(),
                                                       [client_data])
        state_2_names = anonymous_tuple.name_list(server_state_2)
        state_2_arrays = anonymous_tuple.flatten(server_state_2)
        output_2_names = _element_names(server_output_2)
        output_2_arrays = anonymous_tuple.flatten(server_output_2)

        self.assertEmpty(server_state_1.delta_aggregate_state)
        self.assertEmpty(server_state_1.model_broadcast_state)
        self.assertAllEqual(state_1_names, state_2_names)
        self.assertAllEqual(output_1_names, output_2_names)
        self.assertAllClose(state_1_arrays, state_2_arrays)
        self.assertAllClose(output_1_arrays[:2], output_2_arrays[:2])

        # Index 2 of the flattened output holds an execution time, which can
        # only be expected to agree approximately.
        self.assertAlmostEqual(output_1_arrays[2],
                               output_2_arrays[2],
                               delta=1e-3)
Example #6
0
 def assert_weight_lists_match(old_value, new_value):
     """Assert two (possibly nested) lists of ndarrays or tensors match.

     Two leaves (instances of `leaf_types`) match when their `dtype` and
     `shape` are equal. When `new_value` is a sequence and `old_value` is an
     `anonymous_tuple.AnonymousTuple`, the tuple must have no named fields and
     both must have the same length. The two are then compared element-wise
     by recursing on each pair.

     Args:
       old_value: A leaf, or an unnamed `AnonymousTuple` of values to compare.
       new_value: A leaf, or a `collections.abc.Sequence` of values to compare.

     Raises:
       TypeError: If two leaves differ in dtype or shape, if `old_value` has
         named fields, if the two sequences differ in length, or if the pair
         of values matches neither case above.
     """
     # `collections.Sequence` was removed in Python 3.10; the ABC has lived in
     # `collections.abc` since Python 3.3.
     import collections.abc

     # NOTE(review): `leaf_types` is defined outside this block. If it
     # includes plain Python `int`/`float`, the `.dtype` access below would
     # raise AttributeError rather than TypeError -- confirm its definition.
     if isinstance(new_value, leaf_types) and isinstance(
             old_value, leaf_types):
         if (old_value.dtype != new_value.dtype
                 or old_value.shape != new_value.shape):
             raise TypeError('Element is not the same tensor type. old '
                             f'({old_value.dtype}, {old_value.shape}) != '
                             f'new ({new_value.dtype}, {new_value.shape})')
     elif (isinstance(new_value, collections.abc.Sequence)
           and isinstance(old_value, anonymous_tuple.AnonymousTuple)):
         if anonymous_tuple.name_list(old_value):
             # f-string so the offending structure appears in the message.
             raise TypeError(
                 '`tff.learning` does not support named structures of '
                 f'model weights. Received: {old_value}')
         if len(old_value) != len(new_value):
             raise TypeError(
                 'Model weights have different lengths: '
                 f'(old) {len(old_value)} != (new) {len(new_value)}\n'
                 f'Old values: {old_value}\nNew values: {new_value}')
         for old, new in zip(old_value, new_value):
             assert_weight_lists_match(old, new)
     else:
         raise TypeError(
             'Model weights structures contains types that cannot be '
             'handled.\nOld weights structure: {old}\n'
             'New weights structure: {new}\n'
             'Must be one of (int, float, np.ndarray, tf.Tensor, '
             'collections.abc.Sequence)'.format(
                 old=tf.nest.map_structure(type, old_value),
                 new=tf.nest.map_structure(type, new_value)))
 def test_multiple_named_and_unnamed(self):
   """Exercises an AnonymousTuple mixing one unnamed and two named elements."""
   elements = [(None, 10), ('foo', 20), ('bar', 30)]
   x = anonymous_tuple.AnonymousTuple(elements)

   # Positional access and iteration.
   self.assertLen(x, 3)
   for index, expected in enumerate((10, 20, 30)):
     self.assertEqual(x[index], expected)
   with self.assertRaises(IndexError):
     _ = x[3]
   self.assertEqual(list(x), [10, 20, 30])

   # Only the named elements appear in `dir` and `name_list`.
   self.assertEqual(dir(x), ['bar', 'foo'])
   self.assertEqual(anonymous_tuple.name_list(x), ['foo', 'bar'])
   self.assertEqual(x.foo, 20)
   self.assertEqual(x.bar, 30)
   with self.assertRaises(AttributeError):
     _ = x.baz

   # Equality takes element names and their positions into account.
   self.assertEqual(
       x, anonymous_tuple.AnonymousTuple([(None, 10), ('foo', 20),
                                          ('bar', 30)]))
   self.assertNotEqual(
       x, anonymous_tuple.AnonymousTuple([('foo', 10), ('bar', 20),
                                          (None, 30)]))

   # Conversions and string forms.
   self.assertEqual(anonymous_tuple.to_elements(x), elements)
   self.assertEqual(
       repr(x), 'AnonymousTuple([(None, 10), (\'foo\', 20), (\'bar\', 30)])')
   self.assertEqual(str(x), '<10,foo=20,bar=30>')
   with self.assertRaisesRegex(ValueError, 'unnamed'):
     anonymous_tuple.to_odict(x)
Example #8
0
 def _run_test(self, process, *, datasets, expected_num_examples):
     """Runs three rounds of `process` and checks each round's metrics.

     Every round must report `broadcast`, `aggregation` and `train` metrics,
     with the first two empty. The `train` metrics must show the expected
     example count and a loss strictly lower than the previous round's.

     Args:
       process: An object with `initialize()` and `next(state, data)`.
       datasets: The client data passed to every `next` call.
       expected_num_examples: Value `train.num_examples` must equal each round.
     """
     expected_metric_names = ['broadcast', 'aggregation', 'train']
     state = process.initialize()
     loss_to_beat = np.inf
     for _ in range(3):
         state, outputs = process.next(state, datasets)
         self.assertEqual(anonymous_tuple.name_list(outputs),
                          expected_metric_names)
         self.assertEmpty(outputs.broadcast)
         self.assertEmpty(outputs.aggregation)
         train = outputs.train
         self.assertEqual(train.num_examples, expected_num_examples)
         # The loss must strictly decrease from one round to the next.
         self.assertLess(train.loss, loss_to_beat)
         loss_to_beat = train.loss
def _default_from_tff_result_fn(record):
  """Converts an AnonymousTuple TFF result to a dict or list if possible.

  Args:
    record: A value returned by TFF. Anything that is not an
      `anonymous_tuple.AnonymousTuple` is returned unchanged.

  Returns:
    The value of `record._asdict()` when that succeeds. Otherwise, when no
    field of `record` is named, a `list` of its element values.

  Raises:
    ValueError: If `record` has some named and some unnamed fields.
  """
  if not isinstance(record, anonymous_tuple.AnonymousTuple):
    return record
  try:
    return record._asdict()
  except ValueError:
    # `_asdict` rejected the record, so at least one field is unnamed. A
    # record with no named fields at all becomes a list; a partially named
    # record is not supported.
    if anonymous_tuple.name_list(record):
      raise ValueError(
          'Cannot construct a default from a TFF result that '
          'has partially named fields. TFF result: {!s}'.format(record))
    return [value for _, value in anonymous_tuple.iter_elements(record)]
Example #10
0
    def test_execute_empty_data(self):
        """Runs one federated SGD round on two clients with empty datasets.

        After the round, the trainable weights must be all zeros, no training
        examples may be counted, and the reported loss must be NaN.
        """
        iterative_process = federated_sgd.build_federated_sgd_process(
            model_fn=model_examples.LinearRegression)

        # A single example batched with batch size 5 and drop_remainder=True
        # yields a dataset with no batches but with correct types and shapes.
        empty_ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(x=[[1.0, 2.0]], y=[[5.0]])).batch(
                5, drop_remainder=True)
        federated_ds = [empty_ds, empty_ds]

        first_state, metric_outputs = iterative_process.next(
            iterative_process.initialize(), federated_ds)
        self.assertAllClose(list(first_state.model.trainable),
                            [[[0.0], [0.0]], 0.0])

        self.assertEqual(anonymous_tuple.name_list(metric_outputs),
                         ['broadcast', 'aggregation', 'train'])
        self.assertEmpty(metric_outputs.broadcast)
        self.assertEmpty(metric_outputs.aggregation)
        train_metrics = metric_outputs.train
        self.assertEqual(train_metrics.num_examples, 0)
        self.assertTrue(tf.math.is_nan(train_metrics.loss))