예제 #1
0
    def test_raises_mismatched_dataset_comp_return_type_and_iterproc_sequence_type(
            self):
        """Rejects a dataset computation whose element type `next` cannot take.

        `float_dataset_computation` is composed with a process whose `next`
        accepts an int dataset and which otherwise composes cleanly with
        `int_dataset_computation`. The element-type mismatch is therefore the
        only reason for the expected `TypeError`.
        """
        # NOTE: the federated identity process is itself expected to be
        # rejected (its `initialize` returns a dataset, per the
        # init-returns-dataset test), which would let this test pass without
        # ever exercising the dataset element-type check.
        iterproc = _create_stateless_int_dataset_reduction_iterative_process()

        with self.assertRaises(TypeError):
            iterative_process_compositions.compose_dataset_computation_with_iterative_process(
                float_dataset_computation, iterproc)
예제 #2
0
    def test_raises_non_computation(self):
        """Rejects a plain Python callable passed as the dataset computation.

        The process used here composes cleanly with `int_dataset_computation`,
        so the only reason for the expected `TypeError` is that `fn` is not a
        computation.
        """
        fn = lambda _: []
        # NOTE: the federated identity process is itself expected to be
        # rejected (its `initialize` returns a dataset, per the
        # init-returns-dataset test), which would mask whether `fn` being a
        # non-computation is actually what raises.
        iterproc = _create_stateless_int_dataset_reduction_iterative_process()

        with self.assertRaises(TypeError):
            iterative_process_compositions.compose_dataset_computation_with_iterative_process(
                fn, iterproc)
예제 #3
0
    def test_returns_iterproc_accepting_dataset_in_third_index_of_next(self):
        """Composes a dataset computation into the third element of `next`.

        Wraps the reduction process's `next` so it takes
        `(state, tf.int32, dataset)`. The test then checks that composition
        swaps only the third element for the CLIENTS-placed `tf.string` input
        of `int_dataset_computation`, leaving state, the int32 and the result
        type unchanged.
        """
        iterproc = _create_stateless_int_dataset_reduction_iterative_process()

        old_param_type = iterproc.next.type_signature.parameter

        # Insert an unrelated int32 between the state and the dataset so the
        # dataset parameter sits at index 2 rather than index 1.
        new_param_elements = [old_param_type[0], tf.int32, old_param_type[1]]

        @computations.federated_computation(
            computation_types.StructType(new_param_elements))
        def new_next(param):
            # Drop the extra int32 and forward (state, dataset) unchanged.
            return iterproc.next([param[0], param[2]])

        iterproc_with_dataset_as_third_elem = iterative_process.IterativeProcess(
            iterproc.initialize, new_next)
        expected_new_next_type_signature = computation_types.FunctionType([
            computation_types.FederatedType(tf.int64, placements.SERVER),
            tf.int32,
            computation_types.FederatedType(tf.string, placements.CLIENTS)
        ], computation_types.FederatedType(tf.int64, placements.SERVER))

        new_iterproc = iterative_process_compositions.compose_dataset_computation_with_iterative_process(
            int_dataset_computation, iterproc_with_dataset_as_third_elem)

        # `check_equivalent_to` raises a descriptive error on mismatch, unlike
        # `assertTrue(is_equivalent_to(...))` which only reports
        # "False is not true". It also matches the sibling composition tests.
        expected_new_next_type_signature.check_equivalent_to(
            new_iterproc.next.type_signature)
예제 #4
0
    def test_mutates_iterproc_accepting_dataset_in_second_index_of_next(self):
        """Composes a dataset computation into a `(state, data)` `next`.

        After composition, `next` should take the server int64 state plus the
        CLIENTS-placed `tf.string` input of `int_dataset_computation`, and
        still return the server int64 state.
        """
        iterproc = _create_stateless_int_dataset_reduction_iterative_process()

        state_on_server = computation_types.FederatedType(
            tf.int64, placements.SERVER)
        strings_on_clients = computation_types.FederatedType(
            tf.string, placements.CLIENTS)
        expected_params = collections.OrderedDict(
            server_state=state_on_server, client_data=strings_on_clients)
        expected_new_next_type_signature = computation_types.FunctionType(
            expected_params,
            computation_types.FederatedType(tf.int64, placements.SERVER))

        compose = (iterative_process_compositions
                   .compose_dataset_computation_with_iterative_process)
        new_iterproc = compose(int_dataset_computation, iterproc)

        actual_signature = new_iterproc.next.type_signature
        expected_new_next_type_signature.check_equivalent_to(actual_signature)
예제 #5
0
    def test_mutates_iterproc_with_parameter_assignable_from_result(self):
        """Composes when `next`'s state parameter is assignable from its result.

        The state is an int64 vector of unknown length, while `next` returns a
        length-1 int64 vector. The composed `next` must keep both of those
        types and take CLIENTS-placed `tf.string` as its data input.
        """
        make_process = (
            _create_stateless_int_vector_unknown_dim_dataset_reduction_iterative_process)
        iterproc = make_process()

        any_length_vector = computation_types.TensorType(tf.int64, shape=[None])
        length_one_vector = computation_types.TensorType(tf.int64, shape=[1])
        expected_params = collections.OrderedDict(
            server_state=computation_types.FederatedType(
                any_length_vector, placements.SERVER),
            client_data=computation_types.FederatedType(
                tf.string, placements.CLIENTS))
        expected_result = computation_types.FederatedType(
            length_one_vector, placements.SERVER)
        expected_new_next_type_signature = computation_types.FunctionType(
            expected_params, expected_result)

        compose = (iterative_process_compositions
                   .compose_dataset_computation_with_iterative_process)
        new_iterproc = compose(vector_int_dataset_computation, iterproc)

        actual_signature = new_iterproc.next.type_signature
        expected_new_next_type_signature.check_equivalent_to(actual_signature)
예제 #6
0
    def test_raises_iterproc_if_dataset_is_returned_by_init(self):
        """Rejects composition when the process's `initialize` yields a dataset.

        Relies on the federated identity helper producing a dataset from
        `initialize` (per this test's name).
        """
        compose = (iterative_process_compositions
                   .compose_dataset_computation_with_iterative_process)
        process_with_dataset_init = (
            _create_federated_int_dataset_identity_iterative_process())

        with self.assertRaises(TypeError):
            compose(int_dataset_computation, process_with_dataset_init)
예제 #7
0
    def test_raises_iterative_process_no_dataset_parameter(self):
        """Expects a `TypeError` when `next` has no dataset parameter to feed."""
        compose = (iterative_process_compositions
                   .compose_dataset_computation_with_iterative_process)
        process_without_dataset = _create_whimsy_iterative_process()

        with self.assertRaises(TypeError):
            compose(int_dataset_computation, process_without_dataset)
예제 #8
0
    def test_raises_non_iterative_process(self):
        """A plain Python callable is not accepted in place of a process."""

        def identity(x):
            return x

        with self.assertRaises(TypeError):
            iterative_process_compositions.compose_dataset_computation_with_iterative_process(
                int_dataset_computation, identity)