def test_raises_mismatched_dataset_comp_return_type_and_iterproc_sequence_type(
            self):
        """Expects TypeError when the dataset computation's output type does
        not match the sequence type the iterative process consumes.

        `float_dataset_computation` and the int-dataset identity process are
        defined elsewhere in this file; per the test name, their dataset element
        types are meant to disagree.
        """
        int_identity_process = (
            _create_federated_int_dataset_identity_iterative_process())

        self.assertRaises(
            TypeError,
            iterative_process_compositions.compose_dataset_computation,
            float_dataset_computation,
            int_identity_process)

    def test_raises_non_computation(self):
        """Expects TypeError when the dataset computation is a plain Python
        callable rather than a computation object."""

        def not_a_computation(_):
            return []

        int_identity_process = (
            _create_federated_int_dataset_identity_iterative_process())

        self.assertRaises(
            TypeError,
            iterative_process_compositions.compose_dataset_computation,
            not_a_computation,
            int_identity_process)

    def test_returns_iterproc_accepting_dataset_in_third_index_of_next(self):
        """Composition handles a dataset argument placed at index 2 of `next`.

        Wraps a stateless reduction process so that its `next` takes an extra
        `tf.int32` argument in the middle, which moves the dataset argument from
        index 1 to index 2. The test then composes the wrapper with
        `int_dataset_computation`. The resulting `next` is expected to accept
        `[int64@SERVER, int32, string@CLIENTS]` and return `int64@SERVER`.
        """
        iterproc = _create_stateless_int_dataset_reduction_iterative_process()

        old_param_type = iterproc.next.type_signature.parameter

        # Insert an unrelated int32 between the two original parameters so the
        # dataset parameter becomes the third element.
        new_param_elements = [old_param_type[0], tf.int32, old_param_type[1]]

        @computations.federated_computation(
            computation_types.StructType(new_param_elements))
        def new_next(param):
            # Ignore the inserted int32 and forward the original two arguments.
            return iterproc.next([param[0], param[2]])

        iterproc_with_dataset_as_third_elem = iterative_process.IterativeProcess(
            iterproc.initialize, new_next)
        expected_new_next_type_signature = computation_types.FunctionType([
            computation_types.FederatedType(tf.int64,
                                            placement_literals.SERVER),
            tf.int32,
            computation_types.FederatedType(tf.string,
                                            placement_literals.CLIENTS),
        ], computation_types.FederatedType(tf.int64,
                                           placement_literals.SERVER))

        new_iterproc = iterative_process_compositions.compose_dataset_computation(
            int_dataset_computation, iterproc_with_dataset_as_third_elem)

        # check_equivalent_to raises an error describing the mismatch.
        # assertTrue(is_equivalent_to(...)) would only report "False is not true".
        expected_new_next_type_signature.check_equivalent_to(
            new_iterproc.next.type_signature)
# Example No. 4
# 0
    def test_mutates_iterproc_accepting_dataset_in_second_index_of_next(self):
        """Composition rewrites the dataset argument at index 1 of `next`.

        After composing the stateless reduction process with
        `int_dataset_computation`, the resulting `next` is expected to accept
        `[int64@SERVER, string@CLIENTS]` and return `int64@SERVER`.
        """
        iterproc = _create_stateless_int_dataset_reduction_iterative_process()
        expected_new_next_type_signature = tff.FunctionType([
            tff.FederatedType(tf.int64, tff.SERVER),
            tff.FederatedType(tf.string, tff.CLIENTS),
        ], tff.FederatedType(tf.int64, tff.SERVER))

        new_iterproc = iterative_process_compositions.compose_dataset_computation(
            int_dataset_computation, iterproc)

        # check_equivalent_to raises an error describing the mismatch.
        # assertTrue(is_equivalent_to(...)) would only report "False is not true".
        expected_new_next_type_signature.check_equivalent_to(
            new_iterproc.next.type_signature)
# Example No. 5
# 0
    def test_mutates_iterproc_accepting_dataset_in_second_index_of_next(self):
        """Composition rewrites the named `client_data` argument of `next`.

        After composing the stateless reduction process with
        `int_dataset_computation`, the resulting `next` is expected to accept
        `<server_state=int64@SERVER, client_data=string@CLIENTS>` and return
        `int64@SERVER`.
        """
        reduction_process = (
            _create_stateless_int_dataset_reduction_iterative_process())

        server_int64_type = computation_types.FederatedType(
            tf.int64, placement_literals.SERVER)
        client_string_type = computation_types.FederatedType(
            tf.string, placement_literals.CLIENTS)
        expected_param_type = collections.OrderedDict(
            server_state=server_int64_type,
            client_data=client_string_type)
        expected_next_type = computation_types.FunctionType(
            expected_param_type, server_int64_type)

        composed_process = (
            iterative_process_compositions.compose_dataset_computation(
                int_dataset_computation, reduction_process))

        expected_next_type.check_equivalent_to(
            composed_process.next.type_signature)

    def test_raises_iterproc_if_dataset_is_returned_by_init(self):
        """Expects TypeError when composing with the int-dataset identity
        process, whose `initialize` (per the test name) returns a dataset."""
        identity_process = (
            _create_federated_int_dataset_identity_iterative_process())

        self.assertRaises(
            TypeError,
            iterative_process_compositions.compose_dataset_computation,
            int_dataset_computation,
            identity_process)

    def test_raises_iterative_process_no_dataset_parameter(self):
        """Expects TypeError when composing with a dummy process whose `next`
        (per the test name) has no dataset parameter to replace."""
        dummy_process = _create_dummy_iterative_process()

        self.assertRaises(
            TypeError,
            iterative_process_compositions.compose_dataset_computation,
            int_dataset_computation,
            dummy_process)

    def test_raises_non_iterative_process(self):
        """Expects TypeError when the second argument is a plain Python
        callable instead of an iterative process."""

        def identity(x):
            return x

        self.assertRaises(
            TypeError,
            iterative_process_compositions.compose_dataset_computation,
            int_dataset_computation,
            identity)