コード例 #1
0
    def test_returns_canonical_form_from_tff_learning_structure(self):
        """Round-trips the example training process through canonical form.

        Converts the process from `test_utils.construct_example_training_comp`
        into a `CanonicalForm`, rebuilds an iterative process from it, and
        checks that the rebuilt process has compatible type signatures and that
        one round on the same client data yields matching state and metrics.
        """
        it = test_utils.construct_example_training_comp()
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
        new_it = canonical_form_utils.get_iterative_process_for_canonical_form(
            cf)
        self.assertIsInstance(cf, canonical_form.CanonicalForm)
        self.assertEqual(it.initialize.type_signature,
                         new_it.initialize.type_signature)
        # Notice next type_signatures need not be equal, since we may have appended
        # an empty tuple as client side-channel outputs if none existed; so only
        # the parameter and the first two result elements are compared.
        self.assertEqual(it.next.type_signature.parameter,
                         new_it.next.type_signature.parameter)
        self.assertEqual(it.next.type_signature.result[0],
                         new_it.next.type_signature.result[0])
        self.assertEqual(it.next.type_signature.result[1],
                         new_it.next.type_signature.result[1])

        state1 = it.initialize()
        state2 = new_it.initialize()

        sample_batch = collections.OrderedDict(x=np.array([[1., 1.]],
                                                          dtype=np.float32),
                                               y=np.array([[0]],
                                                          dtype=np.int32))
        client_data = [sample_batch]

        # One round of the original process.
        round_1 = it.next(state1, [client_data])
        state = round_1[0]
        state_names = anonymous_tuple.name_list(state)
        state_arrays = anonymous_tuple.flatten(state)
        metrics = round_1[1]
        metrics_names = [x[0] for x in anonymous_tuple.iter_elements(metrics)]
        metrics_arrays = anonymous_tuple.flatten(metrics)

        # One round of the process rebuilt from the canonical form.
        alt_round_1 = new_it.next(state2, [client_data])
        alt_state = alt_round_1[0]
        alt_state_names = anonymous_tuple.name_list(alt_state)
        alt_state_arrays = anonymous_tuple.flatten(alt_state)
        alt_metrics = alt_round_1[1]
        alt_metrics_names = [
            x[0] for x in anonymous_tuple.iter_elements(alt_metrics)
        ]
        alt_metrics_arrays = anonymous_tuple.flatten(alt_metrics)

        # Check both processes symmetrically, not just the original one.
        self.assertEmpty(state.delta_aggregate_state)
        self.assertEmpty(alt_state.delta_aggregate_state)
        self.assertEmpty(state.model_broadcast_state)
        self.assertEmpty(alt_state.model_broadcast_state)
        self.assertAllEqual(state_names, alt_state_names)
        self.assertAllEqual(metrics_names, alt_metrics_names)
        self.assertAllClose(state_arrays, alt_state_arrays)
        self.assertAllClose(metrics_arrays[:2], alt_metrics_arrays[:2])
        # Final metric is execution time.
        # NOTE(review): a 1e-3 absolute tolerance on a timing metric may be
        # flaky across runs -- confirm what this metric measures.
        self.assertAlmostEqual(metrics_arrays[2],
                               alt_metrics_arrays[2],
                               delta=1e-3)
コード例 #2
0
    def test_get_canonical_form_from_fl_api(self):
        """Checks a canonical-form round trip of the example training process.

        Builds an iterative process from the canonical form of
        `test_utils.construct_example_training_comp()` and verifies that its
        initial state and the result of one training round match the original
        process, comparing model weights, optimizer state, the (empty)
        aggregation/broadcast states and the metrics field by field.
        """
        it = test_utils.construct_example_training_comp()
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
        new_it = canonical_form_utils.get_iterative_process_for_canonical_form(
            cf)
        self.assertIsInstance(cf, canonical_form.CanonicalForm)
        self.assertEqual(it.initialize.type_signature,
                         new_it.initialize.type_signature)
        # Notice next type_signatures need not be equal, since we may have appended
        # an empty tuple as client side-channel outputs if none existed; so only
        # the parameter and the first two result elements are compared.
        self.assertEqual(it.next.type_signature.parameter,
                         new_it.next.type_signature.parameter)
        self.assertEqual(it.next.type_signature.result[0],
                         new_it.next.type_signature.result[0])
        self.assertEqual(it.next.type_signature.result[1],
                         new_it.next.type_signature.result[1])

        state1 = it.initialize()
        state2 = new_it.initialize()
        # Initial states are compared through their string representations.
        self.assertEqual(str(state1), str(state2))

        sample_batch = collections.OrderedDict([
            ('x', np.array([[1., 1.]], dtype=np.float32)),
            ('y', np.array([[0]], dtype=np.int32))
        ])
        client_data = [sample_batch]

        round_1 = it.next(state1, [client_data])
        state = round_1[0]
        metrics = round_1[1]

        alt_round_1 = new_it.next(state2, [client_data])
        alt_state = alt_round_1[0]
        alt_metrics = alt_round_1[1]

        # np.testing reports the mismatching elements on failure, whereas
        # assertTrue(np.array_equal(...)) only reports "False is not true".
        np.testing.assert_array_equal(state.model.trainable[0],
                                      alt_state.model.trainable[0])
        np.testing.assert_array_equal(state.model.trainable[1],
                                      alt_state.model.trainable[1])
        self.assertEqual(str(state.model.non_trainable),
                         str(alt_state.model.non_trainable))
        self.assertEqual(state.optimizer_state[0],
                         alt_state.optimizer_state[0])
        self.assertEmpty(state.delta_aggregate_state)
        self.assertEmpty(alt_state.delta_aggregate_state)
        self.assertEmpty(state.model_broadcast_state)
        self.assertEmpty(alt_state.model_broadcast_state)
        self.assertEqual(metrics.sparse_categorical_accuracy,
                         alt_metrics.sparse_categorical_accuracy)
        self.assertEqual(metrics.loss, alt_metrics.loss)
コード例 #3
0
    def test_canonical_form_from_tff_learning_structure_type_spec(self):
        """Pins the compact type signature of the canonical form's `work`.

        Converts the example training process to canonical form and compares
        `work.type_signature.compact_representation()` against a fixed string.
        """
        process = test_utils.construct_example_training_comp()
        form = canonical_form_utils.get_canonical_form_for_iterative_process(
            process)
        work_signature = form.work.type_signature

        # This type spec test actually carries the meaning that TFF's vanilla path
        # to canonical form will broadcast and aggregate exactly one copy of the
        # parameters. So the type test below in fact functions as a regression test
        # for the TFF compiler pipeline.
        # pyformat: disable
        expected = '(<<x=float32[?,2],y=int32[?,1]>*,<<trainable=<float32[2,1],float32[1]>,non_trainable=<>>>> -> <<<<<float32[2,1],float32[1]>,float32>,<float32,float32>,<float32,float32>,<float32>>,<>>,<>>)'
        # pyformat: enable
        actual = work_signature.compact_representation()
        self.assertEqual(actual, expected)
コード例 #4
0
  def test_returns_comps_with_federated_aggregate(self):
    """Splits the example `next` computation around `federated_aggregate`.

    Asserts that the original computation is a Lambda containing at least one
    call to the intrinsic. Asserts that both returned halves are Lambdas with
    no such call. `before` must keep the original parameter type, and `after`
    must keep the original result type.
    """
    process = test_utils.construct_example_training_comp()
    next_comp = test_utils.computation_to_building_block(process.next)
    aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
    before, after = mapreduce_transformations.force_align_and_split_by_intrinsic(
        next_comp, aggregate_uri)

    def _calls_aggregate(node):
      return building_block_analysis.is_called_intrinsic(node, aggregate_uri)

    # The input really does contain the intrinsic we split on.
    self.assertIsInstance(next_comp, building_blocks.Lambda)
    self.assertGreater(tree_analysis.count(next_comp, _calls_aggregate), 0)

    # The "before" half: intrinsic-free, same parameter as the input.
    self.assertIsInstance(before, building_blocks.Lambda)
    self.assertEqual(tree_analysis.count(before, _calls_aggregate), 0)
    self.assertEqual(before.parameter_type, next_comp.parameter_type)

    # The "after" half: intrinsic-free, same result type as the input.
    self.assertIsInstance(after, building_blocks.Lambda)
    self.assertEqual(tree_analysis.count(after, _calls_aggregate), 0)
    self.assertEqual(after.result.type_signature,
                     next_comp.result.type_signature)
コード例 #5
0
 def test_example_training_comp_reduces(self):
     """Checks that the example process's `next` converts to a Lambda."""
     process = mapreduce_test_utils.construct_example_training_comp()
     next_building_block = mapreduce_test_utils.computation_to_building_block(
         process.next)
     self.assertIsInstance(next_building_block, building_blocks.Lambda)