def _check_top_level_compatibility_with_generic_operators(x, y, op_name):
  """Performs non-recursive check on the types of `x` and `y`.

  Verifies that both arguments contain only federated, tuple and tensor types,
  and that their outermost type constructors agree closely enough for the
  generic binary operator `op_name` to be applied to the pair.

  Args:
    x: The first (left-hand) argument; must expose a `type_signature`.
    y: The second (right-hand) argument; must expose a `type_signature`.
    op_name: Human-readable operator name, used only in error messages.

  Returns:
    `None`.

  Raises:
    TypeError: If either argument's type tree contains anything other than
      federated, tuple or tensor types, or if the two types are mismatched at
      the top level. Mismatches include: differing placements or
      upcast-incompatible members for federated values; exactly one of the two
      arguments being federated; tuples paired with incompatible types.
  """
  allowed_types = (computation_types.NamedTupleType,
                   computation_types.TensorType,
                   computation_types.FederatedType)
  x_compatible = type_utils.type_tree_contains_only(x.type_signature,
                                                    allowed_types)
  y_compatible = type_utils.type_tree_contains_only(y.type_signature,
                                                    allowed_types)

  def _make_bad_type_tree_string(index, type_spec):
    return ('{} is only implemented for pairs of '
            'arguments both containing only federated, tuple and '
            'tensor types; you have passed argument at index {} of type {} '
            .format(op_name, index, type_spec))

  if not (x_compatible and y_compatible):
    if y_compatible:
      raise TypeError(_make_bad_type_tree_string(0, x.type_signature))
    elif x_compatible:
      raise TypeError(_make_bad_type_tree_string(1, y.type_signature))
    else:
      raise TypeError(
          '{} is only implemented for pairs of '
          'arguments both containing only federated, tuple and '
          'tensor types; both your arguments fail this condition. '
          'You have passed first argument of type {} '
          'and second argument of type {}.'.format(op_name, x.type_signature,
                                                   y.type_signature))

  top_level_mismatch_string = (
      '{} does not accept arguments of type {} and '
      '{}, as they are mismatched at the top level.'.format(
          op_name, x.type_signature, y.type_signature))

  x_is_federated = isinstance(x.type_signature,
                              computation_types.FederatedType)
  y_is_federated = isinstance(y.type_signature,
                              computation_types.FederatedType)
  if x_is_federated:
    # Federated operands must share a placement, and their members must be
    # compatible under the upcasting binary-operator rules.
    if (not y_is_federated or
        x.type_signature.placement != y.type_signature.placement or
        not type_utils.is_binary_op_with_upcast_compatible_pair(
            x.type_signature.member, y.type_signature.member)):
      raise TypeError(top_level_mismatch_string)
  elif y_is_federated:
    # The branch above is keyed on `x` only; without this check a
    # non-federated `x` paired with a federated `y` would pass unchecked.
    raise TypeError(top_level_mismatch_string)

  if isinstance(x.type_signature, computation_types.NamedTupleType):
    if type_utils.is_binary_op_with_upcast_compatible_pair(
        x.type_signature, y.type_signature):
      return None
    # NOTE(review): `dir()` returns a sorted listing of the tuple types'
    # attributes; presumably this is meant to compare element names, but it
    # ignores element order -- confirm this is the intended check.
    if (not isinstance(y.type_signature, computation_types.NamedTupleType) or
        dir(x.type_signature) != dir(y.type_signature)):
      raise TypeError(top_level_mismatch_string)
  return None
def _check_generic_operator_type(type_spec):
  """Validates that `type_spec` is a legal argument type for a generic op.

  Args:
    type_spec: The candidate argument type. It is expected to be a
      `computation_types.NamedTupleType` with exactly two elements.

  Raises:
    TypeError: If `type_spec` contains anything other than federated, tuple or
      tensor types; if it is not a two-element tuple; or if its two members
      are not compatible with upcasted binary operators.
  """
  permitted_types = (computation_types.FederatedType,
                     computation_types.NamedTupleType,
                     computation_types.TensorType)
  if not type_utils.type_tree_contains_only(type_spec, permitted_types):
    raise TypeError(
        'Generic operators are only implemented for '
        'arguments both containing only federated, tuple and '
        'tensor types; you have passed an argument of type {} '.format(
            type_spec))

  # `isinstance` is evaluated first so `len` is only taken of tuple types.
  is_two_tuple = (
      isinstance(type_spec, computation_types.NamedTupleType) and
      len(type_spec) == 2)
  if not is_two_tuple:
    raise TypeError(
        'We are trying to construct a generic operator declaring argument that '
        'is not a two-tuple, the type {}.'.format(type_spec))

  left_type = type_spec[0]
  right_type = type_spec[1]
  compatible = type_utils.is_binary_op_with_upcast_compatible_pair(
      left_type, right_type)
  if not compatible:
    raise TypeError('The two-tuple you have passed in is incompatible with '
                    'upcasted binary operators. You have passed the tuple '
                    'type {}, which fails the check that the two members of '
                    'the tuple are either the same type, or the second is a '
                    'scalar with the same dtype as the leaves of the first. '
                    'See `type_utils.is_binary_op_with_upcast_compatible_pair` '
                    'for more details.'.format(type_spec))
def construct_generic_constant(type_spec, scalar_value):
  """Creates constant for a combination of federated, tuple and tensor types.

  Args:
    type_spec: Instance of `computation_types.Type` containing only federated,
      tuple or tensor types for which we wish to construct a generic constant.
      May also be something convertible to a `computation_types.Type` via
      `computation_types.to_type`.
    scalar_value: The scalar value we wish this constant to have.

  Returns:
    Instance of `computation_building_blocks.ComputationBuildingBlock`
    representing `scalar_value` packed into `type_spec`.

  Raises:
    TypeError: If types don't match their specification in the args section:
      `scalar_value` is not a scalar, or `type_spec` contains types other than
      federated, tuple or tensor types. Also raised if `type_spec` contains a
      federated type placed somewhere other than `CLIENTS` or `SERVER`.
    ValueError: If `type_spec` somehow evades every construction case (not
      expected to happen given the checks above).

  Notice validation of consistency of `type_spec` with `scalar_value` is not
  the responsibility of this function.
  """
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  inferred_scalar_value_type = type_utils.infer_type(scalar_value)
  if (not isinstance(inferred_scalar_value_type, computation_types.TensorType)
      or inferred_scalar_value_type.shape != tf.TensorShape(())):
    raise TypeError('Must pass a scalar value to '
                    '`construct_generic_constant`; encountered a value '
                    '{}'.format(scalar_value))
  if not type_utils.type_tree_contains_only(
      type_spec, (computation_types.FederatedType,
                  computation_types.NamedTupleType,
                  computation_types.TensorType)):
    raise TypeError(
        'Generic constants can only be constructed for types containing only '
        'federated, tuple and tensor types; you have passed the type '
        '{}.'.format(type_spec))

  if type_utils.type_tree_contains_only(
      type_spec,
      (computation_types.NamedTupleType, computation_types.TensorType)):
    # No federated types anywhere in the tree: one TensorFlow constant covers
    # the whole structure.
    return computation_constructing_utils.construct_tensorflow_constant(
        type_spec, scalar_value)
  elif isinstance(type_spec, computation_types.FederatedType):
    unplaced_constant = (
        computation_constructing_utils.construct_tensorflow_constant(
            type_spec.member, scalar_value))
    if type_spec.placement == placement_literals.CLIENTS:
      placement_uri = intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri
    elif type_spec.placement == placement_literals.SERVER:
      placement_uri = intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri
    else:
      # Without this branch the placement intrinsic would be left unbound and
      # the call below would fail with an unhelpful UnboundLocalError.
      raise TypeError(
          'Generic constants can only be placed at CLIENTS or SERVER; '
          'encountered the federated type {} with placement {}.'.format(
              type_spec, type_spec.placement))
    # Both placement intrinsics share the same signature: member -> the
    # member placed at `type_spec.placement` with all_equal=True.
    placement_fn_type = computation_types.FunctionType(
        type_spec.member,
        computation_types.FederatedType(
            type_spec.member, type_spec.placement, all_equal=True))
    placement_function = computation_building_blocks.Intrinsic(
        placement_uri, placement_fn_type)
    return computation_building_blocks.Call(placement_function,
                                            unplaced_constant)
  elif isinstance(type_spec, computation_types.NamedTupleType):
    # A tuple with federated types somewhere below it: build each element
    # recursively, then reattach the original element names.
    named_element_types = anonymous_tuple.to_elements(type_spec)
    elements = [
        construct_generic_constant(element_type, scalar_value)
        for _, element_type in named_element_types
    ]
    names = [name for name, _ in named_element_types]
    packed_elements = computation_building_blocks.Tuple(elements)
    return computation_constructing_utils.create_named_tuple(
        packed_elements, names)
  else:
    raise ValueError(
        'The type_spec {} has slipped through all our '
        'generic constant cases, and failed to raise.'.format(type_spec))
def _only_tuple_or_tensor(value):
  """Returns whether `value`'s type tree has only tuple and tensor nodes.

  Args:
    value: An object exposing a `type_signature` attribute.

  Returns:
    The result of `type_utils.type_tree_contains_only` applied to
    `value.type_signature` with tuple and tensor types as the allowed set.
  """
  allowed_node_types = (computation_types.NamedTupleType,
                        computation_types.TensorType)
  signature = value.type_signature
  return type_utils.type_tree_contains_only(signature, allowed_node_types)