Beispiel #1
0
    def test_canonical_form_with_learning_structure_does_not_change_execution_of_iterative_process(
            self):
        """Round-trips a training process through CanonicalForm and compares runs.

        The example training process is converted to a CanonicalForm and back
        into an iterative process. The test then checks that:
          * both `initialize` type signatures are equivalent,
          * the parameter and result types of both `next` functions are
            equivalent,
          * one round on a single client's data yields server states and
            server outputs of the same structure with numerically close
            leaves (only the first two output leaves are compared).
        """
        mapreduce = tff.backends.mapreduce
        original_process = construct_example_training_comp()
        canonical_form = mapreduce.get_canonical_form_for_iterative_process(
            original_process)
        round_tripped = mapreduce.get_iterative_process_for_canonical_form(
            canonical_form)

        original_process.initialize.type_signature.check_equivalent_to(
            round_tripped.initialize.type_signature)
        # The whole `next` signatures are not compared directly. Per the
        # original note, the round trip may append an empty tuple as client
        # side-channel outputs when none existed, so only the parameter and
        # result types are checked.
        original_next_sig = original_process.next.type_signature
        round_tripped_next_sig = round_tripped.next.type_signature
        original_next_sig.parameter.check_equivalent_to(
            round_tripped_next_sig.parameter)
        original_next_sig.result.check_equivalent_to(
            round_tripped_next_sig.result)

        # A single client holding a single batch.
        client_data = [
            collections.OrderedDict(
                x=np.array([[1., 1.]], dtype=np.float32),
                y=np.array([[0]], dtype=np.int32),
            )
        ]

        # One round of the original process. Its results are normalized to
        # AnonymousTuples so they can be compared structurally below.
        initial_state_1 = original_process.initialize()
        raw_state_1, raw_output_1 = original_process.next(initial_state_1,
                                                          [client_data])
        server_state_1 = anonymous_tuple.from_container(raw_state_1,
                                                        recursive=True)
        server_output_1 = anonymous_tuple.from_container(raw_output_1,
                                                         recursive=True)
        state_leaves_1 = anonymous_tuple.flatten(server_state_1)
        output_leaves_1 = anonymous_tuple.flatten(server_output_1)

        # One round of the round-tripped process, using the same client data.
        initial_state_2 = round_tripped.initialize()
        server_state_2, server_output_2 = round_tripped.next(initial_state_2,
                                                             [client_data])
        state_leaves_2 = anonymous_tuple.flatten(server_state_2)
        output_leaves_2 = anonymous_tuple.flatten(server_output_2)

        self.assertEmpty(server_state_1.delta_aggregate_state)
        self.assertEmpty(server_state_1.model_broadcast_state)
        # Structures must match exactly. Leaf values are compared with a
        # tolerance instead of assertEqual because floating-point results
        # may differ slightly.
        self.assertTrue(
            anonymous_tuple.is_same_structure(server_state_1, server_state_2))
        self.assertTrue(
            anonymous_tuple.is_same_structure(server_output_1,
                                              server_output_2))
        self.assertAllClose(state_leaves_1, state_leaves_2)
        # NOTE(review): only the first two flattened output leaves are
        # compared; the reason is not visible from this test.
        self.assertAllClose(output_leaves_1[:2], output_leaves_2[:2])
 def test_is_same_structure_check_types(self):
     """Checks `anonymous_tuple.is_same_structure` on several tuple shapes.

     Pairs that differ only in their leaf values are expected to match. This
     holds for flat tuples, nested tuples, and dict leaves with the same keys.
     A dict leaf with a different key set is expected not to match. Passing a
     plain dict instead of an AnonymousTuple is expected to raise TypeError.
     """
     tup = anonymous_tuple.AnonymousTuple
     # (expected result, lhs, rhs), checked in this order.
     cases = (
         # Flat tuples: same field name, different leaf value.
         (True, tup([('a', 10)]), tup([('a', 20)])),
         # Nested tuples: same names at every level, different leaves.
         (True,
          tup([
              ('a', 10),
              ('b', tup([('z', 5)])),
          ]),
          tup([
              ('a', 20),
              ('b', tup([('z', 50)])),
          ])),
         # dict leaves with different key sets.
         (False,
          tup([('x', {'y': 4})]),
          tup([('x', {'y': 5, 'z': 6})])),
         # dict leaves with the same keys but different values.
         (True,
          tup([('x', {'y': 5})]),
          tup([('x', {'y': 6})])),
     )
     for expected_same, lhs, rhs in cases:
         check = self.assertTrue if expected_same else self.assertFalse
         check(anonymous_tuple.is_same_structure(lhs, rhs))

     with self.assertRaises(TypeError):
         # The first argument is a plain dict, not an AnonymousTuple.
         anonymous_tuple.is_same_structure({'x': 5.0}, tup([('x', 5.0)]))