def parse_federated_aggregate_argument_types(type_spec):
  """Validates a federated-aggregate argument type and splits it into parts.

  The five elements of `type_spec` are interpreted, in order, as: a federated
  value, a zero, an accumulate function, a merge function and a report
  function. The accumulate function must be equivalent to
  `type_factory.reduction_op(zero, member)` (where `member` is the federated
  value's member type). The merge function must be equivalent to
  `type_factory.binary_op(zero)`. The report function must be a function type
  whose parameter is equivalent to the zero type.

  Args:
    type_spec: An instance of `computation_types.StructType` with exactly five
      elements.

  Returns:
    A tuple of (value_type, zero_type, accumulate_type, merge_type, report_type)
    for the 5 type constituents.

  Raises:
    Whatever `py_typecheck.check_type`, `py_typecheck.check_len` or
    `check_equivalent_to` raise when a constraint is violated (presumably
    `TypeError` -- TODO confirm against those helpers).
  """
  py_typecheck.check_type(type_spec, computation_types.StructType)
  py_typecheck.check_len(type_spec, 5)

  # Element 0: the federated value. Its member type is what `accumulate`
  # folds into the running zero.
  federated_value = type_spec[0]
  py_typecheck.check_type(federated_value, computation_types.FederatedType)
  member = federated_value.member

  # Element 1: the zero; it fixes the intermediate type for everything below.
  zero = type_spec[1]

  # Element 2: accumulate, combining (zero, member) back into zero.
  accumulate = type_spec[2]
  expected_accumulate = type_factory.reduction_op(zero, member)
  accumulate.check_equivalent_to(expected_accumulate)

  # Element 3: merge, combining two zeros into one.
  merge = type_spec[3]
  merge.check_equivalent_to(type_factory.binary_op(zero))

  # Element 4: report. Only its parameter is constrained; its result type is
  # left unchecked here.
  report = type_spec[4]
  py_typecheck.check_type(report, computation_types.FunctionType)
  report.parameter.check_equivalent_to(zero)

  return federated_value, zero, accumulate, merge, report
def test_binary_op(self):
  # binary_op(T) should yield the function type (<T,T> -> T); check it
  # structurally against a hand-built FunctionType over a bool tensor.
  bool_type = computation_types.TensorType(tf.bool)
  result = type_factory.binary_op(bool_type)
  operand_pair = computation_types.NamedTupleType([bool_type, bool_type])
  self.assertEqual(result,
                   computation_types.FunctionType(operand_pair, bool_type))
def create_dummy_intrinsic_def_federated_aggregate():
  """Returns `FEDERATED_AGGREGATE` paired with a concrete float32 signature.

  Every type slot of the aggregate signature is filled with `tf.float32`.
  The parameter holds a float32 value placed at clients, a float32 zero, an
  accumulate op, a merge op and a float32 -> float32 report function. The
  result is float32 placed at the server.

  Returns:
    A `(intrinsic_def, type_signature)` tuple.
  """
  scalar = tf.float32
  parameter_types = [
      computation_types.at_clients(scalar),  # value
      scalar,  # zero
      type_factory.reduction_op(scalar, scalar),  # accumulate
      type_factory.binary_op(scalar),  # merge
      computation_types.FunctionType(scalar, scalar),  # report
  ]
  signature = computation_types.FunctionType(
      parameter_types, computation_types.at_server(scalar))
  return intrinsic_defs.FEDERATED_AGGREGATE, signature
# a = generic_partial_reduce(x, zero, accumulate, INTERMEDIATE_AGGREGATORS) # b = generic_reduce(a, zero, merge, SERVER) # c = generic_map(report, b) # return c # # Actual implementations might vary. # # Type signature: <{T}@CLIENTS,U,(<U,T>->U),(<U,U>->U),(U->R)> -> R@SERVER FEDERATED_AGGREGATE = IntrinsicDef( 'FEDERATED_AGGREGATE', 'federated_aggregate', computation_types.FunctionType(parameter=[ type_factory.at_clients(computation_types.AbstractType('T')), computation_types.AbstractType('U'), type_factory.reduction_op(computation_types.AbstractType('U'), computation_types.AbstractType('T')), type_factory.binary_op(computation_types.AbstractType('U')), computation_types.FunctionType(computation_types.AbstractType('U'), computation_types.AbstractType('R')) ], result=type_factory.at_server( computation_types.AbstractType('R')))) # Applies a given function to a value on the server. # # Type signature: <(T->U),T@SERVER> -> U@SERVER FEDERATED_APPLY = IntrinsicDef( 'FEDERATED_APPLY', 'federated_apply', computation_types.FunctionType(parameter=[ computation_types.FunctionType(computation_types.AbstractType('T'), computation_types.AbstractType('U')), type_factory.at_server(computation_types.AbstractType('T')),
def test_binary_op(self):
  # The printed form of binary_op(T) is the function type (<T,T> -> T).
  rendered = str(type_factory.binary_op(tf.bool))
  self.assertEqual(rendered, '(<bool,bool> -> bool)')