def test_returns_computation_sequence(self):
        """Identity lambda over a sequence type has signature `(T -> T)`."""
        seq_type = computation_types.SequenceType(tf.int32)
        comp_proto = computation_factory.create_lambda_identity(seq_type)

        self.assertIsInstance(comp_proto, pb.Computation)
        # Deserialize the proto's type and compare with the unary-op type
        # built from the same input type.
        self.assertEqual(
            type_serialization.deserialize_type(comp_proto.type),
            type_factory.unary_op(seq_type))
    def test_returns_computation_tuple_unnamed(self):
        """Identity lambda over an unnamed struct has signature `(T -> T)`."""
        struct_type = computation_types.StructType([tf.int32, tf.float32])
        comp_proto = computation_factory.create_lambda_identity(struct_type)

        self.assertIsInstance(comp_proto, pb.Computation)
        # Deserialize the proto's type and compare with the unary-op type
        # built from the same input type.
        self.assertEqual(
            type_serialization.deserialize_type(comp_proto.type),
            type_factory.unary_op(struct_type))
Example #3
0
    async def _compute_intrinsic_federated_aggregate(self, arg):
        """Computes a federated aggregate across the child executors.

        Each child executor runs `FEDERATED_AGGREGATE` over its own share of
        the values, using the caller's `zero`, `accumulate` and `merge` but an
        identity function in place of `report`. The children's partial results
        are then embedded in the parent executor. There they are folded
        left-to-right with `merge`, and the caller's `report` is applied once
        to the merged value.

        Args:
          arg: A value whose `internal_representation` is an
            `anonymous_tuple.AnonymousTuple` of length 5:
            `(values, zero, accumulate, merge, report)`. `values` must be a
            list with exactly one entry per child executor. The argument
            types are parsed from `arg.type_signature` by
            `executor_utils.parse_federated_aggregate_argument_types`.

        Returns:
          A `CompositeValue` wrapping the parent executor's call of `report`
          on the merged result, typed as `at_server(report_type.result)`.

        Raises:
          Whatever `py_typecheck.check_type` / `py_typecheck.check_len` raise
          if `arg` does not have the structure described above.
        """
        value_type, zero_type, accumulate_type, merge_type, report_type = (
            executor_utils.parse_federated_aggregate_argument_types(
                arg.type_signature))
        py_typecheck.check_type(arg.internal_representation,
                                anonymous_tuple.AnonymousTuple)
        py_typecheck.check_len(arg.internal_representation, 5)
        val = arg.internal_representation[0]
        py_typecheck.check_type(val, list)
        # One value per child executor; they are zipped together below.
        py_typecheck.check_len(val, len(self._child_executors))
        # The children report with an identity function, so each returns its
        # un-reported partial aggregate (typed `zero_type`). The real `report`
        # is applied only once, on the parent, after merging.
        identity_report = computation_factory.create_lambda_identity(zero_type)
        identity_report_type = type_factory.unary_op(zero_type)
        aggr_type = computation_types.FunctionType(
            computation_types.NamedTupleType([
                value_type, zero_type, accumulate_type, merge_type,
                identity_report_type
            ]), type_factory.at_server(zero_type))
        aggr_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_AGGREGATE, aggr_type)
        # `zero` is computed to a concrete value here so it can be re-embedded
        # into every child executor below.
        zero = await (await self.create_selection(arg, index=1)).compute()
        accumulate = arg.internal_representation[2]
        merge = arg.internal_representation[3]
        report = arg.internal_representation[4]

        async def _child_fn(ex, v):
            """Runs the identity-report aggregate on child `ex` over `v`."""
            py_typecheck.check_type(v, executor_value_base.ExecutorValue)
            # NOTE(review): the inner `await asyncio.gather(...)` is evaluated
            # while the outer gather's arguments are being built. It therefore
            # presumably completes before the `aggr_comp` embedding starts
            # running, rather than overlapping with it.
            aggr_func, aggr_args = await asyncio.gather(
                ex.create_value(aggr_comp, aggr_type),
                ex.create_tuple([v] + list(await asyncio.gather(
                    ex.create_value(zero, zero_type),
                    ex.create_value(accumulate, accumulate_type),
                    ex.create_value(merge, merge_type),
                    ex.create_value(identity_report, identity_report_type)))))
            return await (await ex.create_call(aggr_func, aggr_args)).compute()

        # Run the per-child aggregates concurrently.
        vals = await asyncio.gather(
            *[_child_fn(c, v) for c, v in zip(self._child_executors, val)])
        # Re-embed each child's partial result on the parent as `zero_type`.
        parent_vals = await asyncio.gather(
            *[self._parent_executor.create_value(v, zero_type) for v in vals])
        parent_merge, parent_report = await asyncio.gather(
            self._parent_executor.create_value(merge, merge_type),
            self._parent_executor.create_value(report, report_type))
        # Sequential left fold of the partial results with `merge`.
        # NOTE(review): with zero child executors `parent_vals` is empty and
        # this indexing raises IndexError.
        merge_result = parent_vals[0]
        for next_val in parent_vals[1:]:
            merge_result = await self._parent_executor.create_call(
                parent_merge, await
                self._parent_executor.create_tuple([merge_result, next_val]))
        return CompositeValue(
            await self._parent_executor.create_call(parent_report,
                                                    merge_result),
            type_factory.at_server(report_type.result))
 async def _compute_intrinsic_federated_sum(self, arg):
   """Computes a `federated_sum` by delegating to `federated_aggregate`.

   The sum is expressed as an aggregate with a scalar constant `0` as the
   zero, `tf.add` as both the accumulate and the merge operator, and an
   identity lambda as the report.

   Args:
     arg: A value whose type must be federated and placed at `CLIENTS`.

   Returns:
     The result of `self._compute_intrinsic_federated_aggregate` on the
     constructed 5-tuple.
   """
   type_utils.check_federated_type(
       arg.type_signature, placement=placement_literals.CLIENTS)
   member_type = arg.type_signature.member
   zero, plus, identity = await asyncio.gather(
       executor_utils.embed_tf_scalar_constant(self, member_type, 0),
       executor_utils.embed_tf_binary_operator(self, member_type, tf.add),
       self.create_value(
           computation_factory.create_lambda_identity(member_type),
           type_factory.unary_op(member_type)))
   # `plus` serves as both the accumulate and the merge operator.
   aggregate_args = await self.create_tuple(
       [arg, zero, plus, plus, identity])
   return await self._compute_intrinsic_federated_aggregate(aggregate_args)
Example #5
0
def create_dummy_computation_lambda_identity():
    """Builds an identity lambda computation over `float32`.

    Returns:
      A `(computation, type_signature)` pair, where `type_signature` is the
      function type `(float32 -> float32)`.
    """
    float_type = computation_types.TensorType(tf.float32)
    identity_comp = computation_factory.create_lambda_identity(float_type)
    return (identity_comp,
            computation_types.FunctionType(float_type, float_type))
def create_dummy_computation_lambda_identity():
    """Builds an identity lambda computation over `float32`.

    Returns:
      A `(computation, type_signature)` pair, where `type_signature` is the
      function type `(float32 -> float32)` built from the raw `tf.float32`
      dtype.
    """
    # NOTE(review): this shares its name with the definition immediately
    # before it. Because it comes later, this one replaces that one at
    # import time.
    element_type = tf.float32
    identity_comp = computation_factory.create_lambda_identity(element_type)
    return (identity_comp,
            computation_types.FunctionType(element_type, element_type))