Ejemplo n.º 1
0
def parse_federated_aggregate_argument_types(type_spec):
    """Verifies and parses `type_spec` into constituents.

    The five elements correspond, in order, to the value, zero, accumulate,
    merge and report arguments of a federated aggregate.

    Args:
      type_spec: An instance of `computation_types.StructType`.

    Returns:
      A tuple of (value_type, zero_type, accumulate_type, merge_type,
      report_type) for the 5 type constituents.

    Raises:
      TypeError: If the report function declares no parameter. Other
        mismatches are reported by the `py_typecheck` helpers and by
        `check_equivalent_to`.
    """
    py_typecheck.check_type(type_spec, computation_types.StructType)
    py_typecheck.check_len(type_spec, 5)
    value_type = type_spec[0]
    py_typecheck.check_type(value_type, computation_types.FederatedType)
    # Type of a single member of the federated value being aggregated.
    item_type = value_type.member
    zero_type = type_spec[1]
    # Accumulate must fold one member item into a `zero_type` accumulator.
    accumulate_type = type_spec[2]
    accumulate_type.check_equivalent_to(
        type_factory.reduction_op(zero_type, item_type))
    # Merge must combine two `zero_type` accumulators into one.
    merge_type = type_spec[3]
    merge_type.check_equivalent_to(type_factory.binary_op(zero_type))
    report_type = type_spec[4]
    py_typecheck.check_type(report_type, computation_types.FunctionType)
    # A report function with no parameter (`parameter` is None) would
    # otherwise fail below with an unhelpful AttributeError on None.
    if report_type.parameter is None:
        raise TypeError(
            'Expected the report function to accept a parameter of type {}, '
            'but it declares no parameter: {}.'.format(zero_type, report_type))
    report_type.parameter.check_equivalent_to(zero_type)
    return value_type, zero_type, accumulate_type, merge_type, report_type
Ejemplo n.º 2
0
 def test_binary_op(self):
     """Checks that `binary_op(T)` builds the function type `(<T,T> -> T)`."""
     bool_type = computation_types.TensorType(tf.bool)
     result = type_factory.binary_op(bool_type)
     operand_pair = computation_types.NamedTupleType([bool_type, bool_type])
     self.assertEqual(
         result, computation_types.FunctionType(operand_pair, bool_type))
Ejemplo n.º 3
0
def create_dummy_intrinsic_def_federated_aggregate():
    """Builds a `FEDERATED_AGGREGATE` intrinsic paired with a concrete signature.

    Every type slot of the aggregate signature is filled with `tf.float32`:
    a float32 value at CLIENTS, a float32 zero, float32 accumulate and merge
    operators, and a float32 -> float32 report, producing float32 at SERVER.

    Returns:
      A `(value, type_signature)` tuple, where `value` is
      `intrinsic_defs.FEDERATED_AGGREGATE` and `type_signature` is the
      concrete `computation_types.FunctionType` described above.
    """
    intrinsic = intrinsic_defs.FEDERATED_AGGREGATE
    scalar = tf.float32
    parameter_types = [
        computation_types.at_clients(scalar),
        scalar,
        type_factory.reduction_op(scalar, scalar),
        type_factory.binary_op(scalar),
        computation_types.FunctionType(scalar, scalar),
    ]
    signature = computation_types.FunctionType(
        parameter_types, computation_types.at_server(scalar))
    return intrinsic, signature
Ejemplo n.º 4
0
#   a = generic_partial_reduce(x, zero, accumulate, INTERMEDIATE_AGGREGATORS)
#   b = generic_reduce(a, zero, merge, SERVER)
#   c = generic_map(report, b)
#   return c
#
# Actual implementations might vary.
#
# Type signature: <{T}@CLIENTS,U,(<U,T>->U),(<U,U>->U),(U->R)> -> R@SERVER
FEDERATED_AGGREGATE = IntrinsicDef(
    'FEDERATED_AGGREGATE',
    'federated_aggregate',
    computation_types.FunctionType(
        parameter=[
            # {T}@CLIENTS: the client-placed value to aggregate.
            type_factory.at_clients(computation_types.AbstractType('T')),
            # U: the zero value of the accumulator.
            computation_types.AbstractType('U'),
            # (<U,T>->U): accumulate.
            type_factory.reduction_op(
                computation_types.AbstractType('U'),
                computation_types.AbstractType('T')),
            # (<U,U>->U): merge.
            type_factory.binary_op(computation_types.AbstractType('U')),
            # (U->R): report.
            computation_types.FunctionType(
                computation_types.AbstractType('U'),
                computation_types.AbstractType('R')),
        ],
        # R@SERVER: the reported result.
        result=type_factory.at_server(computation_types.AbstractType('R')),
    ),
)

# Applies a given function to a value on the server.
#
# Type signature: <(T->U),T@SERVER> -> U@SERVER
FEDERATED_APPLY = IntrinsicDef(
    'FEDERATED_APPLY', 'federated_apply',
    computation_types.FunctionType(parameter=[
        computation_types.FunctionType(computation_types.AbstractType('T'),
                                       computation_types.AbstractType('U')),
        type_factory.at_server(computation_types.AbstractType('T')),
 def test_binary_op(self):
   """The string form of `binary_op(tf.bool)` is `(<bool,bool> -> bool)`."""
   rendered = str(type_factory.binary_op(tf.bool))
   self.assertEqual(rendered, '(<bool,bool> -> bool)')