Example #1
0
 def test_serialize_deserialize_federated_at_clients(self):
   """Round-trips a CLIENTS-placed int32 list through (de)serialization."""
   client_values = [10, 20]
   clients_type = type_factory.at_clients(tf.int32)
   serialized, serialized_type = executor_service_utils.serialize_value(
       client_values, clients_type)
   self.assertIsInstance(serialized, executor_pb2.Value)
   self.assertEqual(str(serialized_type), '{int32}@CLIENTS')
   deserialized, deserialized_type = (
       executor_service_utils.deserialize_value(serialized))
   self.assertEqual(str(deserialized_type), str(clients_type))
   self.assertEqual(deserialized, [10, 20])
    def test_returns_value_with_federated_type_at_clients_all_equal(self):
        """An all-equal CLIENTS value computes to its single member value."""
        eager_values = [eager_tf_executor.EagerValue(10.0, None, tf.float32)]
        all_equal_clients_type = type_factory.at_clients(
            tf.float32, all_equal=True)
        strategy_value = (
            federated_resolving_strategy.FederatedResolvingStrategyValue(
                eager_values, all_equal_clients_type))

        computed = self.run_sync(strategy_value.compute())

        self.assertEqual(computed, 10.0)
    async def compute_federated_zip_at_clients(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Zips two CLIENTS-placed values pairwise on every target executor.

        Args:
          arg: A value whose type signature is a two-element
            `computation_types.StructType` of CLIENTS-placed federated types,
            and whose internal representation is a two-element
            `anonymous_tuple.AnonymousTuple` of lists, each holding one
            `ExecutorValue` per target executor.

        Returns:
          A `FederatedComposingStrategyValue` holding one zipped value per
          target executor, typed as `type_factory.at_clients` of the struct of
          the two member types (element names preserved where present).
        """
        py_typecheck.check_type(arg.type_signature,
                                computation_types.StructType)
        py_typecheck.check_len(arg.type_signature, 2)
        py_typecheck.check_type(arg.internal_representation,
                                anonymous_tuple.AnonymousTuple)
        py_typecheck.check_len(arg.internal_representation, 2)
        element_names = [
            name for name, _ in anonymous_tuple.to_elements(arg.type_signature)
        ]

        clients_types = []
        per_target_vals = []
        for index in range(2):
            member_type = arg.type_signature[index]
            member_vals = arg.internal_representation[index]
            type_analysis.check_federated_type(
                member_type, placement=placement_literals.CLIENTS)
            # Re-wrap with `at_clients` (default `all_equal`), discarding any
            # all-equal marker carried by the input type.
            clients_types.append(type_factory.at_clients(member_type.member))
            py_typecheck.check_type(member_vals, list)
            py_typecheck.check_len(member_vals, len(self._target_executors))
            per_target_vals.append(member_vals)

        def _maybe_named(name, value):
            # Unnamed struct elements are passed bare; named ones as pairs.
            return (name, value) if name else value

        item_type = computation_types.StructType([
            _maybe_named(element_names[i], clients_types[i].member)
            for i in range(2)
        ])
        result_type = type_factory.at_clients(item_type)
        zip_param_type = computation_types.StructType([
            _maybe_named(element_names[i], clients_types[i]) for i in range(2)
        ])
        zip_type = computation_types.FunctionType(zip_param_type, result_type)
        zip_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip_type)

        async def _zip_on_child(child, first, second):
            py_typecheck.check_type(first, executor_value_base.ExecutorValue)
            py_typecheck.check_type(second, executor_value_base.ExecutorValue)
            zip_fn = await child.create_value(zip_comp, zip_type)
            zip_arg = await child.create_struct(
                anonymous_tuple.AnonymousTuple([(element_names[0], first),
                                                (element_names[1], second)]))
            return await child.create_call(zip_fn, zip_arg)

        zipped = await asyncio.gather(*[
            _zip_on_child(child, first, second)
            for child, first, second in zip(self._target_executors,
                                             per_target_vals[0],
                                             per_target_vals[1])
        ])
        return FederatedComposingStrategyValue(zipped, result_type)
def create_dummy_intrinsic_def_federated_aggregate():
    """Returns FEDERATED_AGGREGATE with a concrete all-float32 signature.

    Returns:
      An `(intrinsic, type_signature)` pair. The signature takes CLIENTS-placed
      `tf.float32` values, a `tf.float32` zero, accumulate / merge operators
      and a `tf.float32 -> tf.float32` report function, and yields a
      SERVER-placed `tf.float32`.
    """
    intrinsic = intrinsic_defs.FEDERATED_AGGREGATE
    parameter_types = [
        type_factory.at_clients(tf.float32),  # values
        tf.float32,  # zero
        type_factory.reduction_op(tf.float32, tf.float32),  # accumulate
        type_factory.binary_op(tf.float32),  # merge
        computation_types.FunctionType(tf.float32, tf.float32),  # report
    ]
    result_type = type_factory.at_server(tf.float32)
    signature = computation_types.FunctionType(parameter_types, result_type)
    return intrinsic, signature
    def test_federated_mean(self):
        """federated_mean over client inputs 1.0, 2.0, ... equals 6.5."""

        @computations.federated_computation(
            type_factory.at_clients(tf.float32))
        def comp(x):
            return intrinsics.federated_mean(x)

        executor, num_clients = _create_test_executor()
        client_inputs = [float(i) for i in range(1, num_clients + 1)]
        mean = _invoke(executor, comp, client_inputs)
        self.assertEqual(mean, 6.5)
Example #6
0
    def test_execution_with_inferred_clients_larger_than_fanout(
            self, executor_factory_fn):
        """Sums ten client values through an executor with max fan-out 3."""

        @computations.federated_computation(type_factory.at_clients(tf.int32))
        def foo(x):
            return intrinsics.federated_sum(x)

        executor = executor_factory_fn(max_fanout=3)
        client_values = list(range(1, 11))
        with executor_test_utils.install_executor(executor):
            total = foo(client_values)

        self.assertEqual(total, 55)
Example #7
0
  def test_returns_value_with_federated_type_at_clients(self):
    """Computing a CLIENTS-placed value yields the list of member values."""
    eager_values = [
        eager_tf_executor.EagerValue(member, None, tf.float32)
        for member in (10.0, 11.0, 12.0)
    ]
    clients_type = type_factory.at_clients(tf.float32)
    federated_value = federating_executor.FederatingExecutorValue(
        eager_values, clients_type)

    computed = self.run_sync(federated_value.compute())

    self.assertEqual(computed, [10.0, 11.0, 12.0])
    async def _map(self, arg, all_equal=None):
        """Maps a function over a federated value on every target executor.

        Args:
          arg: A value whose internal representation is a two-element
            `anonymous_tuple.AnonymousTuple` `(fn, vals)`: `fn` is a
            `pb.Computation` and `vals` is a list of per-target-executor
            values (paired with `self._target_executors` via `zip`). Its type
            signature is `<FunctionType, FederatedType>`.
          all_equal: Whether the result type is all-equal. When `None`, the
            `all_equal` bit of the federated argument type is used.

        Returns:
          A `FederatedComposingStrategyValue` holding one mapped value per
          target executor, with the function's result type placed like the
          argument.

        Raises:
          ValueError: If an all-equal result is requested for an argument
            that is not all-equal.
        """
        py_typecheck.check_type(arg.internal_representation,
                                anonymous_tuple.AnonymousTuple)
        py_typecheck.check_len(arg.internal_representation, 2)
        fn_type = arg.type_signature[0]
        py_typecheck.check_type(fn_type, computation_types.FunctionType)
        val_type = arg.type_signature[1]
        py_typecheck.check_type(val_type, computation_types.FederatedType)
        if all_equal is not None:
            if all_equal and not val_type.all_equal:
                raise ValueError(
                    'Cannot map a non-all_equal argument into an all_equal result.'
                )
        else:
            all_equal = val_type.all_equal
        fn_proto = arg.internal_representation[0]
        py_typecheck.check_type(fn_proto, pb.Computation)
        member_vals = arg.internal_representation[1]
        py_typecheck.check_type(member_vals, list)

        # Each child executor receives a plain (non-all-equal) CLIENTS map.
        map_type = computation_types.FunctionType(
            [fn_type, type_factory.at_clients(fn_type.parameter)],
            type_factory.at_clients(fn_type.result))
        map_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_MAP, map_type)

        async def _map_on_child(child, member):
            py_typecheck.check_type(member, executor_value_base.ExecutorValue)
            child_fn = await child.create_value(fn_proto, fn_type)
            child_map, child_arg = await asyncio.gather(
                child.create_value(map_comp, map_type),
                child.create_struct([child_fn, member]))
            return await child.create_call(child_map, child_arg)

        mapped = await asyncio.gather(*(
            _map_on_child(child, member)
            for child, member in zip(self._target_executors, member_vals)))
        result_type = computation_types.FederatedType(fn_type.result,
                                                      val_type.placement,
                                                      all_equal=all_equal)
        return FederatedComposingStrategyValue(mapped, result_type)
Example #9
0
  def test_returns_value_with_unplaced_type_and_clients(self, executor):
    """Placing an unplaced value at CLIENTS yields an all-equal value."""
    raw_value, unplaced_type = (
        executor_test_utils.create_dummy_value_unplaced())

    embedded = self.run_sync(executor.create_value(raw_value, unplaced_type))
    federated = self.run_sync(
        executor_utils.compute_intrinsic_federated_value(
            executor, embedded, placement_literals.CLIENTS))

    self.assertIsInstance(federated, executor_value_base.ExecutorValue)
    expected_type = type_factory.at_clients(unplaced_type, all_equal=True)
    self.assertEqual(federated.type_signature.compact_representation(),
                     expected_type.compact_representation())
    self.assertEqual(self.run_sync(federated.compute()), 10.0)
    def test_recovers_from_raising(self):
        """A stack that raised once produces the correct mean afterwards."""

        class _RaisingExecutor(eager_tf_executor.EagerTFExecutor):
            """An executor which can be configured to raise on `create_value`."""

            def __init__(self):
                self._should_raise = True
                super().__init__()

            def stop_raising(self):
                self._should_raise = False

            async def create_value(self, *args, **kwargs):
                if self._should_raise:
                    raise AssertionError
                return await super().create_value(*args, **kwargs)

        raising_executors = [_RaisingExecutor(), _RaisingExecutor()]

        strategy_factory = (
            federated_resolving_strategy.FederatedResolvingStrategy.factory({
                placement_literals.SERVER: _create_worker_stack(),
                placement_literals.CLIENTS: raising_executors,
            }))
        federating_ex = federating_executor.FederatingExecutor(
            strategy_factory, _create_worker_stack())

        # The same federating executor is reused for every worker slot.
        worker_stacks = [federating_ex] * 3
        num_middle_stacks = 2
        executor = _create_middle_stack([
            _create_middle_stack(worker_stacks)
            for _ in range(num_middle_stacks)
        ])

        @computations.federated_computation(
            type_factory.at_clients(tf.float32))
        def comp(x):
            return intrinsics.federated_mean(x)

        # 2 clients per worker stack * 3 worker stacks * 2 middle stacks
        num_clients = (len(raising_executors) * len(worker_stacks) *
                       num_middle_stacks)
        arg = [float(i) for i in range(1, num_clients + 1)]

        with self.assertRaises(AssertionError):
            _invoke(executor, comp, arg)

        for raising_ex in raising_executors:
            raising_ex.stop_raising()

        self.assertEqual(_invoke(executor, comp, arg), 6.5)
Example #11
0
    def test_changing_cardinalities_across_calls(self):
        """An identity computation accepts different client counts per call."""

        @computations.federated_computation(type_factory.at_clients(tf.int32))
        def comp(x):
            return x

        five_ints, ten_ints = list(range(5)), list(range(10))

        with executor_test_utils.install_executor(
                executor_stacks.local_executor_factory()):
            five = comp(five_ints)
            ten = comp(ten_ints)

        self.assertEqual(five, five_ints)
        self.assertEqual(ten, ten_ints)
Example #12
0
  def test_returns_value_with_federated_type_at_server(self, executor,
                                                       num_clients):
    """Broadcasting a SERVER value yields an all-equal CLIENTS value."""
    del num_clients  # Unused.
    raw_value, server_type = (
        executor_test_utils.create_dummy_value_at_server())

    embedded = self.run_sync(executor.create_value(raw_value, server_type))
    broadcast = self.run_sync(
        executor_utils.compute_intrinsic_federated_broadcast(executor, embedded))

    self.assertIsInstance(broadcast, executor_value_base.ExecutorValue)
    expected_type = type_factory.at_clients(server_type.member, all_equal=True)
    self.assertEqual(broadcast.type_signature.compact_representation(),
                     expected_type.compact_representation())
    self.assertEqual(self.run_sync(broadcast.compute()), 10.0)
Example #13
0
    def test_get_size_info(self, num_clients):
        """Checks the sizing executor's broadcast and aggregate accounting."""

        @computations.federated_computation(
            type_factory.at_clients(
                computation_types.SequenceType(tf.float32)),
            type_factory.at_server(tf.float32))
        def comp(temperatures, threshold):
            client_data = [
                temperatures,
                intrinsics.federated_broadcast(threshold),
            ]
            result_map = intrinsics.federated_map(
                count_over, intrinsics.federated_zip(client_data))
            count_map = intrinsics.federated_map(count_total, temperatures)
            return intrinsics.federated_mean(result_map, count_map)

        sizing_factory = executor_stacks.sizing_executor_factory(
            num_clients=num_clients)
        sizing_context = execution_context.ExecutionContext(sizing_factory)
        with get_context_stack.get_context_stack().install(sizing_context):

            def to_float(x):
                return tf.cast(x, tf.float32)

            client_dataset = tf.data.Dataset.range(10).map(to_float)
            comp([client_dataset] * num_clients, 15.0)

            # Each client receives a tf.float32 and uploads two tf.float32 values.
            float32_bits = 32
            clients_key = (('CLIENTS', num_clients),)
            expected_broadcast_bits = [num_clients * float32_bits]
            expected_aggregate_bits = [num_clients * float32_bits * 2]
            expected_broadcast_history = {
                clients_key: [[1, tf.float32]] * num_clients
            }
            expected_aggregate_history = {
                clients_key: [[1, tf.float32]] * num_clients * 2
            }

            size_info = sizing_factory.get_size_info()

            self.assertEqual(expected_broadcast_history,
                             size_info.broadcast_history)
            self.assertEqual(expected_aggregate_history,
                             size_info.aggregate_history)
            self.assertEqual(expected_broadcast_bits, size_info.broadcast_bits)
            self.assertEqual(expected_aggregate_bits, size_info.aggregate_bits)
def _create_tff_parallel_clients_with_dataset_reduce():
    """Builds a federated computation that reduces each client's dataset.

    Returns:
      A federated computation that maps every CLIENTS-placed `tf.int64`
      dataset to the sum of its elements, starting from an accumulator of 1
      held in a `tf.Variable`.
    """

    @tf.function
    def add(lhs, rhs):
        return lhs + rhs

    @tf.function
    def reduce_dataset(dataset, initial_value):
        return dataset.reduce(initial_value, add)

    @computations.tf_computation(computation_types.SequenceType(tf.int64))
    def dataset_reduce_fn_wrapper(ds):
        # The starting accumulator is an int64 `tf.Variable` with value 1.
        start = tf.Variable(np.int64(1.0))
        return reduce_dataset(ds, start)

    @computations.federated_computation(
        type_factory.at_clients(computation_types.SequenceType(tf.int64)))
    def parallel_client_run(client_datasets):
        return intrinsics.federated_map(dataset_reduce_fn_wrapper,
                                        client_datasets)

    return parallel_client_run
Example #15
0
 async def _compute_intrinsic_federated_mean(self, arg):
     """Computes the mean of a CLIENTS-placed `arg`, returned at SERVER.

     Every client value is zipped with the constant 1, and the pairs are
     federated-summed, so the second component counts the contributions. The
     summed total is then multiplied by `1 / count` on the parent executor.
     """
     member_type = arg.type_signature.member
     ones = await self.create_value(
         1, type_factory.at_clients(member_type, all_equal=True))
     value_and_one = await self.create_tuple([arg, ones])
     zipped = await self._compute_intrinsic_federated_zip_at_clients(
         value_and_one)
     summed = await self._compute_intrinsic_federated_sum(zipped)
     totals = summed.internal_representation
     py_typecheck.check_type(totals, executor_value_base.ExecutorValue)
     fed_sum, count = await asyncio.gather(
         self._parent_executor.create_selection(totals, index=0),
         self._parent_executor.create_selection(totals, index=1))
     count_val = await count.compute()
     factor, multiply = await asyncio.gather(
         executor_utils.embed_tf_scalar_constant(
             self._parent_executor, member_type, float(1.0 / count_val)),
         executor_utils.embed_tf_binary_operator(self._parent_executor,
                                                 member_type, tf.multiply))
     multiply_arg = await self._parent_executor.create_tuple(
         [fed_sum, factor])
     scaled = await self._parent_executor.create_call(
         multiply, multiply_arg)
     return CompositeValue(scaled, type_factory.at_server(member_type))
def create_dummy_value_at_clients_all_equal():
    """Returns a Python value and federated type at clients and all equal.

    Returns:
      A `(value, type_signature)` pair: the float `10.0` and an all-equal
      CLIENTS-placed `tf.float32` type.
    """
    clients_type = type_factory.at_clients(tf.float32, all_equal=True)
    return 10.0, clients_type
def create_dummy_value_at_clients(number_of_clients: int = 3):
    """Returns a Python value and federated type at clients.

    Args:
      number_of_clients: How many per-client values to produce.

    Returns:
      A `(value, type_signature)` pair: the floats `10.0, 11.0, ...` (one per
      client) and a CLIENTS-placed `tf.float32` type.
    """
    client_values = [float(10 + offset) for offset in range(number_of_clients)]
    clients_type = type_factory.at_clients(tf.float32)
    return client_values, clients_type
def create_dummy_intrinsic_def_federated_value_at_clients():
    """Returns FEDERATED_VALUE_AT_CLIENTS with a concrete float32 signature.

    Returns:
      An `(intrinsic, type_signature)` pair. The signature maps a `tf.float32`
      to an all-equal CLIENTS-placed `tf.float32`.
    """
    result_type = type_factory.at_clients(tf.float32, all_equal=True)
    signature = computation_types.FunctionType(tf.float32, result_type)
    return intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS, signature
def create_dummy_intrinsic_def_federated_sum():
    """Returns FEDERATED_SUM with a concrete float32 signature.

    Returns:
      An `(intrinsic, type_signature)` pair. The signature maps CLIENTS-placed
      `tf.float32` values to a SERVER-placed `tf.float32`.
    """
    parameter_type = type_factory.at_clients(tf.float32)
    result_type = type_factory.at_server(tf.float32)
    signature = computation_types.FunctionType(parameter_type, result_type)
    return intrinsic_defs.FEDERATED_SUM, signature
def create_dummy_intrinsic_def_federated_eval_at_clients():
    """Returns FEDERATED_EVAL_AT_CLIENTS with a concrete float32 signature.

    Returns:
      An `(intrinsic, type_signature)` pair. The signature maps a
      no-parameter function returning `tf.float32` to CLIENTS-placed
      `tf.float32` values.
    """
    thunk_type = computation_types.FunctionType(None, tf.float32)
    result_type = type_factory.at_clients(tf.float32)
    signature = computation_types.FunctionType(thunk_type, result_type)
    return intrinsic_defs.FEDERATED_EVAL_AT_CLIENTS, signature
def create_dummy_intrinsic_def_federated_collect():
    """Returns FEDERATED_COLLECT with a concrete float32 signature.

    Returns:
      An `(intrinsic, type_signature)` pair. The signature maps CLIENTS-placed
      `tf.float32` values to a SERVER-placed sequence of `tf.float32`.
    """
    parameter_type = type_factory.at_clients(tf.float32)
    result_type = type_factory.at_server(
        computation_types.SequenceType(tf.float32))
    signature = computation_types.FunctionType(parameter_type, result_type)
    return intrinsic_defs.FEDERATED_COLLECT, signature
def create_dummy_intrinsic_def_federated_broadcast():
    """Returns FEDERATED_BROADCAST with a concrete float32 signature.

    Returns:
      An `(intrinsic, type_signature)` pair. The signature maps a
      SERVER-placed `tf.float32` to an all-equal CLIENTS-placed `tf.float32`.
    """
    parameter_type = type_factory.at_server(tf.float32)
    result_type = type_factory.at_clients(tf.float32, all_equal=True)
    signature = computation_types.FunctionType(parameter_type, result_type)
    return intrinsic_defs.FEDERATED_BROADCAST, signature
Example #23
0
 def test_at_clients(self):
     """`at_clients` builds a CLIENTS-placed `FederatedType`."""
     bool_type = computation_types.TensorType(tf.bool)
     self.assertEqual(
         type_factory.at_clients(bool_type),
         computation_types.FederatedType(bool_type,
                                         placement_literals.CLIENTS))
Example #24
0
# intrinsics defined above, as follows.
#
# @federated_computation
# def federated_aggregate(x, zero, accumulate, merge, report):
#   a = generic_partial_reduce(x, zero, accumulate, INTERMEDIATE_AGGREGATORS)
#   b = generic_reduce(a, zero, merge, SERVER)
#   c = generic_map(report, b)
#   return c
#
# Actual implementations might vary.
#
# Type signature: <{T}@CLIENTS,U,(<U,T>->U),(<U,U>->U),(U->R)> -> R@SERVER
FEDERATED_AGGREGATE = IntrinsicDef(
    'FEDERATED_AGGREGATE',
    'federated_aggregate',
    computation_types.FunctionType(
        parameter=[
            # x: {T}@CLIENTS
            type_factory.at_clients(computation_types.AbstractType('T')),
            # zero: U
            computation_types.AbstractType('U'),
            # accumulate: (<U,T> -> U)
            type_factory.reduction_op(computation_types.AbstractType('U'),
                                      computation_types.AbstractType('T')),
            # merge: (<U,U> -> U)
            type_factory.binary_op(computation_types.AbstractType('U')),
            # report: (U -> R)
            computation_types.FunctionType(
                computation_types.AbstractType('U'),
                computation_types.AbstractType('R')),
        ],
        # Result: R@SERVER
        result=type_factory.at_server(computation_types.AbstractType('R'))))

# Applies a given function to a value on the server.
#
# Type signature: <(T->U),T@SERVER> -> U@SERVER
FEDERATED_APPLY = IntrinsicDef(
    'FEDERATED_APPLY', 'federated_apply',
Example #25
0
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor, arg: executor_value_base.ExecutorValue
) -> executor_value_base.ExecutorValue:
  """Computes a federated weighted mean on the given `executor`.

  The mean is assembled from intrinsic calls issued to `executor`. The two
  elements of `arg` are zipped at CLIENTS and multiplied member-wise with
  FEDERATED_MAP. The products and the element-1 values (the weights) are each
  summed to SERVER with FEDERATED_SUM. The two sums are zipped at SERVER, and
  the first is divided by the second with FEDERATED_APPLY.

  Args:
    executor: The executor to use.
    arg: The argument, already embedded in `executor`. Its type signature must
      pass `type_analysis.check_valid_federated_weighted_mean_argument_tuple_type`;
      element 0 holds the values and element 1 the weights.

  Returns:
    The result embedded in `executor`.
  """
  type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
      arg.type_signature)
  # FEDERATED_ZIP_AT_CLIENTS type that pairs each client's value with its
  # weight: <{V}@CLIENTS, {W}@CLIENTS> -> {<V, W>}@CLIENTS.
  zip1_type = computation_types.FunctionType(
      computation_types.StructType([
          type_factory.at_clients(arg.type_signature[0].member),
          type_factory.at_clients(arg.type_signature[1].member)
      ]),
      type_factory.at_clients(
          computation_types.StructType(
              [arg.type_signature[0].member, arg.type_signature[1].member])))

  # TensorFlow operator applying `tf.multiply` to a zipped <value, weight>.
  multiply_blk = building_block_factory.create_tensorflow_binary_operator_with_upcast(
      zip1_type.result.member, tf.multiply)

  # FEDERATED_MAP type applying `multiply_blk` to every zipped pair.
  map_type = computation_types.FunctionType(
      computation_types.StructType(
          [multiply_blk.type_signature, zip1_type.result]),
      type_factory.at_clients(multiply_blk.type_signature.result))

  # FEDERATED_SUM type for the per-client products (CLIENTS -> SERVER).
  sum1_type = computation_types.FunctionType(
      type_factory.at_clients(map_type.result.member),
      type_factory.at_server(map_type.result.member))

  # FEDERATED_SUM type for the weights (element 1 of `arg`).
  sum2_type = computation_types.FunctionType(
      type_factory.at_clients(arg.type_signature[1].member),
      type_factory.at_server(arg.type_signature[1].member))

  # FEDERATED_ZIP_AT_SERVER type pairing <sum of products, total weight>.
  zip2_type = computation_types.FunctionType(
      computation_types.StructType([sum1_type.result, sum2_type.result]),
      type_factory.at_server(
          computation_types.StructType(
              [sum1_type.result.member, sum2_type.result.member])))

  # TensorFlow operator applying `tf.divide` to <sum of products, weight>.
  divide_blk = building_block_factory.create_tensorflow_binary_operator_with_upcast(
      zip2_type.result.member, tf.divide)

  # Each helper below embeds one node of the computation in `executor`. Where
  # a node has two independent inputs, they are awaited concurrently with
  # `asyncio.gather`.

  async def _compute_multiply_fn():
    return await executor.create_value(multiply_blk.proto,
                                       multiply_blk.type_signature)

  async def _compute_multiply_arg():
    # The zipped <value, weight> pairs at CLIENTS.
    zip1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS,
                                      zip1_type)
    zip_fn = await executor.create_value(zip1_comp, zip1_type)
    return await executor.create_call(zip_fn, arg)

  async def _compute_product_fn():
    map_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_MAP, map_type)
    return await executor.create_value(map_comp, map_type)

  async def _compute_product_arg():
    multiply_fn, multiply_arg = await asyncio.gather(_compute_multiply_fn(),
                                                     _compute_multiply_arg())
    return await executor.create_struct((multiply_fn, multiply_arg))

  async def _compute_products():
    # Per-client value * weight.
    product_fn, product_arg = await asyncio.gather(_compute_product_fn(),
                                                   _compute_product_arg())
    return await executor.create_call(product_fn, product_arg)

  async def _compute_total_weight():
    # Sum of element 1 of `arg` (the weights), placed at SERVER.
    sum2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum2_type)
    sum2_fn, sum2_arg = await asyncio.gather(
        executor.create_value(sum2_comp, sum2_type),
        executor.create_selection(arg, index=1))
    return await executor.create_call(sum2_fn, sum2_arg)

  async def _compute_sum_of_products():
    sum1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum1_type)
    sum1_fn, products = await asyncio.gather(
        executor.create_value(sum1_comp, sum1_type), _compute_products())
    return await executor.create_call(sum1_fn, products)

  async def _compute_zip2_fn():
    zip2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_SERVER,
                                      zip2_type)
    return await executor.create_value(zip2_comp, zip2_type)

  async def _compute_zip2_arg():
    sum_of_products, total_weight = await asyncio.gather(
        _compute_sum_of_products(), _compute_total_weight())
    return await executor.create_struct([sum_of_products, total_weight])

  async def _compute_divide_fn():
    return await executor.create_value(divide_blk.proto,
                                       divide_blk.type_signature)

  async def _compute_divide_arg():
    # <sum of products, total weight> zipped at SERVER.
    zip_fn, zip_arg = await asyncio.gather(_compute_zip2_fn(),
                                           _compute_zip2_arg())
    return await executor.create_call(zip_fn, zip_arg)

  async def _compute_apply_fn():
    apply_type = computation_types.FunctionType(
        computation_types.StructType(
            [divide_blk.type_signature, zip2_type.result]),
        type_factory.at_server(divide_blk.type_signature.result))
    apply_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY,
                                       apply_type)
    return await executor.create_value(apply_comp, apply_type)

  async def _compute_apply_arg():
    divide_fn, divide_arg = await asyncio.gather(_compute_divide_fn(),
                                                 _compute_divide_arg())
    return await executor.create_struct([divide_fn, divide_arg])

  async def _compute_divided():
    # Root of the graph: FEDERATED_APPLY(divide, zipped sums) at SERVER.
    apply_fn, apply_arg = await asyncio.gather(_compute_apply_fn(),
                                               _compute_apply_arg())
    return await executor.create_call(apply_fn, apply_arg)

  return await _compute_divided()
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor, arg: executor_value_base.ExecutorValue
) -> executor_value_base.ExecutorValue:
    """Computes a federated weighted mean on the given `executor`.

    Stages, each issued to `executor` in sequence: zip values with weights at
    CLIENTS; multiply each pair with FEDERATED_MAP; FEDERATED_SUM the products
    and, separately, the weights (element 1 of `arg`) to SERVER; zip the two
    sums at SERVER; divide with FEDERATED_APPLY.

    Args:
      executor: The executor to use.
      arg: The argument, already embedded in `executor`. Its type signature
        must pass
        `type_analysis.check_valid_federated_weighted_mean_argument_tuple_type`;
        element 0 holds the values and element 1 the weights.

    Returns:
      The result embedded in `executor`.
    """
    type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
        arg.type_signature)
    # Stage 1: pair each client's value with its weight at CLIENTS.
    zip1_type = computation_types.FunctionType(
        computation_types.NamedTupleType([
            type_factory.at_clients(arg.type_signature[0].member),
            type_factory.at_clients(arg.type_signature[1].member)
        ]),
        type_factory.at_clients(
            computation_types.NamedTupleType(
                [arg.type_signature[0].member, arg.type_signature[1].member])))
    zip1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS,
                                      zip1_type)
    zipped_arg = await executor.create_call(
        await executor.create_value(zip1_comp, zip1_type), arg)

    # TODO(b/134543154): Replace with something that produces a section of
    # plain TensorFlow code instead of constructing a lambda (so that this
    # can be executed directly on top of a plain TensorFlow-based executor).
    multiply_blk = building_block_factory.create_binary_operator_with_upcast(
        zipped_arg.type_signature.member, tf.multiply)

    # Stage 2: FEDERATED_MAP `multiply_blk` over the zipped pairs.
    map_type = computation_types.FunctionType(
        computation_types.NamedTupleType(
            [multiply_blk.type_signature, zipped_arg.type_signature]),
        type_factory.at_clients(multiply_blk.type_signature.result))
    map_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_MAP, map_type)
    products = await executor.create_call(
        await executor.create_value(map_comp, map_type), await
        executor.create_tuple([
            await executor.create_value(multiply_blk.proto,
                                        multiply_blk.type_signature),
            zipped_arg
        ]))
    # Stage 3a: sum the per-client products to SERVER.
    sum1_type = computation_types.FunctionType(
        type_factory.at_clients(products.type_signature.member),
        type_factory.at_server(products.type_signature.member))
    sum1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum1_type)
    sum_of_products = await executor.create_call(
        await executor.create_value(sum1_comp, sum1_type), products)
    # Stage 3b: sum the weights (element 1 of `arg`) to SERVER.
    sum2_type = computation_types.FunctionType(
        type_factory.at_clients(arg.type_signature[1].member),
        type_factory.at_server(arg.type_signature[1].member))
    sum2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum2_type)
    total_weight = await executor.create_call(
        *(await asyncio.gather(executor.create_value(sum2_comp, sum2_type),
                               executor.create_selection(arg, index=1))))
    # Stage 4: zip <sum of products, total weight> at SERVER.
    zip2_type = computation_types.FunctionType(
        computation_types.NamedTupleType(
            [sum_of_products.type_signature, total_weight.type_signature]),
        type_factory.at_server(
            computation_types.NamedTupleType([
                sum_of_products.type_signature.member,
                total_weight.type_signature.member
            ])))
    zip2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_SERVER,
                                      zip2_type)
    divide_arg = await executor.create_call(*(await asyncio.gather(
        executor.create_value(zip2_comp, zip2_type),
        executor.create_tuple([sum_of_products, total_weight]))))
    # Stage 5: FEDERATED_APPLY `tf.divide` to the zipped sums at SERVER.
    divide_blk = building_block_factory.create_binary_operator_with_upcast(
        divide_arg.type_signature.member, tf.divide)
    apply_type = computation_types.FunctionType(
        computation_types.NamedTupleType(
            [divide_blk.type_signature, divide_arg.type_signature]),
        type_factory.at_server(divide_blk.type_signature.result))
    apply_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY,
                                       apply_type)
    return await executor.create_call(*(await asyncio.gather(
        executor.create_value(apply_comp, apply_type),
        executor.create_tuple([
            await executor.create_value(divide_blk.proto,
                                        divide_blk.type_signature), divide_arg
        ]))))
Example #27
0
 def test_raises_type_error(self):
   """float32 weighted-mean arguments validate; int32 ones raise TypeError."""
   float_clients = type_factory.at_clients(tf.float32)
   type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
       computation_types.StructType([float_clients, float_clients]))
   with self.assertRaises(TypeError):
     int_clients = type_factory.at_clients(tf.int32)
     type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
         computation_types.StructType([int_clients, int_clients]))
 def test_at_clients(self):
   """A CLIENTS-placed bool type renders as `{bool}@CLIENTS`."""
   clients_bool = type_factory.at_clients(tf.bool)
   self.assertEqual(str(clients_bool), '{bool}@CLIENTS')