コード例 #1
0
ファイル: intrinsics_test.py プロジェクト: HaiQW/federated
    def test_federated_aggregate_with_federated_zero_fails(self):
        """A SERVER-placed `zero` must be rejected by `federated_aggregate`.

        `federated_aggregate` is handed a `zero` produced by a federated
        computation, so its type is `int32@SERVER`. The accumulate, merge and
        report operators all work on plain `int32`. The call is expected to
        raise a `TypeError` whose message names both the expected type and the
        incompatible placed type.
        """

        @tff.federated_computation()
        def build_federated_zero():
            # The result carries a SERVER placement; per the expected error
            # below, that placed type is what makes it an invalid `zero`.
            return tff.federated_value(0, tff.SERVER)

        # The operator to use during the first stage adds an element to the
        # accumulator.
        @tff.tf_computation([tf.int32, tf.int32])
        def accumulate(accu, elem):
            return accu + elem

        # The operator to use during the second stage adds two partial
        # accumulators together.
        @tff.tf_computation([tf.int32, tf.int32])
        def merge(x, y):
            return x + y

        # The operator to use during the final stage returns the merged
        # accumulator unchanged.
        @tff.tf_computation(tf.int32)
        def report(accu):
            return accu

        def foo(x):
            return tff.federated_aggregate(x, build_federated_zero(),
                                           accumulate, merge, report)

        with self.assertRaisesRegex(
                TypeError, 'Expected `zero` to be assignable to type int32, '
                'but was of incompatible type int32@SERVER'):
            tff.federated_computation(foo,
                                      tff.FederatedType(tf.int32, tff.CLIENTS))
コード例 #2
0
ファイル: intrinsics_test.py プロジェクト: HaiQW/federated
    def test_federated_map_injected_zip_fails_different_placements(self):
        """`federated_map` over a SERVER value and a CLIENTS value must fail.

        The mapped computation receives a two-element list whose members are
        placed at different locations (SERVER and CLIENTS). A `TypeError`
        complaining about mixed placements is expected. Presumably this is
        raised while trying to implicitly zip the list; the test name suggests
        so, but this should be confirmed against the intrinsic implementation.
        """
        expected_message = ('You cannot apply federated_map on nested values '
                            'with mixed placements.')

        def foo(x, y):
            greater_than_ten = tff.tf_computation(lambda x, y: x > 10,
                                                  [tf.int32, tf.int32])
            mixed_pair = [x, y]
            return tff.federated_map(greater_than_ten, mixed_pair)

        with self.assertRaisesRegex(TypeError, expected_message):
            server_int = tff.FederatedType(tf.int32, tff.SERVER)
            clients_int = tff.FederatedType(tf.int32, tff.CLIENTS)
            tff.federated_computation(foo, [server_int, clients_int])
コード例 #3
0
    def test_federated_map_injected_zip_fails_different_placements(self):
        """`federated_map` over a SERVER value and a CLIENTS value must fail.

        The argument to map is a two-element list mixing a SERVER-placed and
        a CLIENTS-placed `int32`. A `TypeError` is expected, stating that the
        mapped value must be (or be implicitly convertible to) a
        `FederatedType`.
        """
        expected_message = ('The value to be mapped must be a FederatedType '
                            'or implicitly convertible to a FederatedType.')

        def foo(x, y):
            greater_than_ten = tff.tf_computation(lambda x, y: x > 10,
                                                  [tf.int32, tf.int32])
            mixed_pair = [x, y]
            return tff.federated_map(greater_than_ten, mixed_pair)

        with self.assertRaisesRegex(TypeError, expected_message):
            server_int = tff.FederatedType(tf.int32, tff.SERVER)
            clients_int = tff.FederatedType(tf.int32, tff.CLIENTS)
            tff.federated_computation(foo, [server_int, clients_int])
コード例 #4
0
    def test_build_encoded_sum(self, value_constructor, encoder_constructor):
        """Checks the types produced by `build_encoded_sum`.

        Builds an encoded-sum aggregator for a 20-element random value and
        traces its `_next_fn` with (state@SERVER, value@CLIENTS,
        float32@CLIENTS) parameters. The float32@CLIENTS argument is presumably
        a per-client weight (TODO confirm). The test then verifies that the
        aggregator is a `StatefulAggregateFn` and that the traced function
        returns the state and the value, both placed at SERVER, with their
        original member types.
        """
        tensor = value_constructor(np.random.rand(20))
        spec = tf.TensorSpec(tensor.shape, tf.as_dtype(tensor.dtype))
        expected_value_type = tff.to_type(spec)
        gather_encoder = te.encoders.as_gather_encoder(encoder_constructor(),
                                                       spec)
        aggregate_fn = encoding_utils.build_encoded_sum(tensor, gather_encoder)
        expected_state_type = aggregate_fn._initialize_fn.type_signature.result

        state_param = tff.FederatedType(expected_state_type, tff.SERVER)
        value_param = tff.FederatedType(expected_value_type, tff.CLIENTS)
        weight_param = tff.FederatedType(tff.to_type(tf.float32), tff.CLIENTS)
        traced = tff.federated_computation(aggregate_fn._next_fn, state_param,
                                           value_param, weight_param)
        signature = traced.type_signature

        self.assertIsInstance(aggregate_fn, StatefulAggregateFn)
        state_out = signature.result[0]
        self.assertEqual(expected_state_type, state_out.member)
        self.assertEqual(tff.SERVER, state_out.placement)
        value_out = signature.result[1]
        self.assertEqual(expected_value_type, value_out.member)
        self.assertEqual(tff.SERVER, value_out.placement)
コード例 #5
0
    def test_build_encoded_broadcast(self, value_constructor,
                                     encoder_constructor):
        """Checks the types produced by `build_encoded_broadcast`.

        Builds an encoded-broadcast function for a 20-element random value and
        traces its `_next_fn` with (state@SERVER, value@SERVER) parameters.
        The test then verifies that the function is a `StatefulBroadcastFn`,
        that the returned state stays at SERVER, and that the returned value,
        with an unchanged member type, is placed at CLIENTS.
        """
        value = value_constructor(np.random.rand(20))
        value_spec = tf.TensorSpec(value.shape, tf.as_dtype(value.dtype))
        value_type = tff.to_type(value_spec)
        encoder = te.encoders.as_simple_encoder(encoder_constructor(),
                                                value_spec)
        broadcast_fn = encoding_utils.build_encoded_broadcast(value, encoder)
        state_type = broadcast_fn._initialize_fn.type_signature.result
        # Reuse `state_type` rather than re-reading the initializer's
        # signature, so the state type fed into the trace and the one asserted
        # on below come from a single lookup.
        broadcast_signature = tff.federated_computation(
            broadcast_fn._next_fn,
            tff.FederatedType(state_type, tff.SERVER),
            tff.FederatedType(value_type, tff.SERVER)).type_signature

        self.assertIsInstance(broadcast_fn, StatefulBroadcastFn)
        self.assertEqual(state_type, broadcast_signature.result[0].member)
        self.assertEqual(tff.SERVER, broadcast_signature.result[0].placement)
        self.assertEqual(value_type, broadcast_signature.result[1].member)
        self.assertEqual(tff.CLIENTS, broadcast_signature.result[1].placement)