def _check_top_level_compatibility_with_generic_operators(x, y, op_name):
  """Performs non-recursive check on the types of `x` and `y`.

  Validates that both arguments' type trees contain only federated, tuple
  and tensor types, and that the two types match at the top level (same
  placement for federated types, upcast-compatible members, matching named
  tuple attributes). Deeper (recursive) compatibility is not checked here.

  Args:
    x: First value-like argument; only its `type_signature` is inspected.
    y: Second value-like argument; only its `type_signature` is inspected.
    op_name: Human-readable operator name, used only in error messages.

  Raises:
    TypeError: If either argument contains a disallowed type, or the two
      type signatures are mismatched at the top level.
  """
  # Each argument must contain only federated, tuple and tensor types
  # anywhere in its type tree (e.g. no function or sequence types).
  x_compatible = type_analysis.contains_only_types(
      x.type_signature,
      (computation_types.NamedTupleType, computation_types.TensorType,
       computation_types.FederatedType))
  y_compatible = type_analysis.contains_only_types(
      y.type_signature,
      (computation_types.NamedTupleType, computation_types.TensorType,
       computation_types.FederatedType))

  def _make_bad_type_tree_string(index, type_spec):
    # Error text for the case where exactly one argument is incompatible;
    # `index` identifies which positional argument failed.
    return ('{} is only implemented for pairs of '
            'arguments both containing only federated, tuple and '
            'tensor types; you have passed argument at index {} of type {} '
            .format(op_name, index, type_spec))

  if not (x_compatible and y_compatible):
    # Distinguish the three failure modes (only x bad, only y bad, both bad)
    # so the error message points at the offending argument(s).
    if y_compatible:
      raise TypeError(_make_bad_type_tree_string(0, x.type_signature))
    elif x_compatible:
      raise TypeError(_make_bad_type_tree_string(1, y.type_signature))
    else:
      raise TypeError(
          '{} is only implemented for pairs of '
          'arguments both containing only federated, tuple and '
          'tensor types; both your arguments fail this condition. '
          'You have passed first argument of type {} '
          'and second argument of type {}.'.format(op_name, x.type_signature,
                                                   y.type_signature))
  top_level_mismatch_string = (
      '{} does not accept arguments of type {} and '
      '{}, as they are mismatched at the top level.'.format(
          op_name, x.type_signature, y.type_signature))
  if isinstance(x.type_signature, computation_types.FederatedType):
    # Federated arguments must share a placement and have members that are
    # either identical or upcast-compatible (second a matching scalar).
    if (not isinstance(y.type_signature, computation_types.FederatedType) or
        x.type_signature.placement != y.type_signature.placement or
        not type_utils.is_binary_op_with_upcast_compatible_pair(
            x.type_signature.member, y.type_signature.member)):
      raise TypeError(top_level_mismatch_string)
  if isinstance(x.type_signature, computation_types.NamedTupleType):
    if type_utils.is_binary_op_with_upcast_compatible_pair(
        x.type_signature, y.type_signature):
      # Upcast-compatible pair; nothing further to check at the top level.
      return None
    # Otherwise both must be named tuples with the same attribute names.
    # NOTE(review): `dir()` on the type objects appears to be used here to
    # compare named-tuple element names — confirm against
    # `NamedTupleType.__dir__` semantics.
    elif not isinstance(y.type_signature,
                        computation_types.NamedTupleType) or dir(
                            x.type_signature) != dir(y.type_signature):
      raise TypeError(top_level_mismatch_string)
def _pack_binary_operator_args(x, y):
  """Packs arguments to binary operator into a single arg.

  Two plain (non-federated) arguments are packed into a two-tuple; two
  federated arguments with the same placement are tupled and then zipped
  into a single federated two-tuple via `federated_zip`.

  Args:
    x: First value to pack; only tuple/tensor type trees or federated types
      are supported.
    y: Second value to pack, subject to the same restrictions as `x`.

  Returns:
    A single value whose type is a two-tuple (possibly federated) of the
    argument types.

  Raises:
    TypeError: If the two arguments cannot be packed together — e.g.
      federated members that are not upcast-compatible, mismatched
      placements, or an unsupported mixture of federated and
      non-federated arguments.
  """

  def _only_tuple_or_tensor(value):
    # True iff the value's type tree contains no federated (or other) types.
    return type_analysis.contains_only_types(
        value.type_signature,
        (computation_types.NamedTupleType, computation_types.TensorType))

  if _only_tuple_or_tensor(x) and _only_tuple_or_tensor(y):
    # Purely local values: a plain tuple suffices.
    arg = value_impl.ValueImpl(
        building_blocks.Tuple([
            value_impl.ValueImpl.get_comp(x),
            value_impl.ValueImpl.get_comp(y)
        ]), context_stack)
  elif (isinstance(x.type_signature, computation_types.FederatedType) and
        isinstance(y.type_signature, computation_types.FederatedType) and
        x.type_signature.placement == y.type_signature.placement):
    if not type_utils.is_binary_op_with_upcast_compatible_pair(
        x.type_signature.member, y.type_signature.member):
      # Fixed wording: this helper packs args for any generic binary
      # operator, not specifically division.
      raise TypeError(
          'The members of the federated types {} and {} are not binary '
          'operator compatible; see '
          '`type_utils.is_binary_op_with_upcast_compatible_pair` '
          'for more details.'.format(x.type_signature, y.type_signature))
    # Tuple the two federated values, then zip so the result is a single
    # federated value of two-tuple member type.
    packed_arg = value_impl.ValueImpl(
        building_blocks.Tuple([
            value_impl.ValueImpl.get_comp(x),
            value_impl.ValueImpl.get_comp(y)
        ]), context_stack)
    arg = intrinsics.federated_zip(packed_arg)
  else:
    # Previously a bare `raise TypeError` with no message; give the caller
    # an actionable diagnostic instead.
    raise TypeError(
        'Cannot pack arguments of types {} and {} into a single argument; '
        'expected either two values containing only tuple and tensor types, '
        'or two federated values with the same placement.'.format(
            x.type_signature, y.type_signature))
  return arg
def _check_generic_operator_type(type_spec):
  """Checks that `type_spec` can be the signature of args to a generic op.

  Args:
    type_spec: The candidate argument type; must be a two-tuple whose type
      tree contains only federated, tuple and tensor types, and whose two
      elements are upcast-compatible.

  Raises:
    TypeError: If `type_spec` fails any of the checks above.
  """
  # Consistency fix: sibling helpers in this file use
  # `type_analysis.contains_only_types` for this exact check (the original
  # called the older `type_utils.type_tree_contains_only` alias here).
  if not type_analysis.contains_only_types(
      type_spec,
      (computation_types.FederatedType, computation_types.NamedTupleType,
       computation_types.TensorType)):
    raise TypeError(
        'Generic operators are only implemented for '
        'arguments both containing only federated, tuple and '
        'tensor types; you have passed an argument of type {} '.format(
            type_spec))
  # Generic operators take exactly one packed argument: a two-tuple.
  if not (isinstance(type_spec, computation_types.NamedTupleType) and
          len(type_spec) == 2):
    raise TypeError(
        'We are trying to construct a generic operator declaring argument that '
        'is not a two-tuple, the type {}.'.format(type_spec))
  # The two elements must be the same type, or the second a scalar whose
  # dtype matches the leaves of the first (upcast compatibility).
  if not type_utils.is_binary_op_with_upcast_compatible_pair(
      type_spec[0], type_spec[1]):
    raise TypeError('The two-tuple you have passed in is incompatible with '
                    'upcasted binary operators. You have passed the tuple '
                    'type {}, which fails the check that the two members of '
                    'the tuple are either the same type, or the second is a '
                    'scalar with the same dtype as the leaves of the first. '
                    'See `type_utils.is_binary_op_with_upcast_compatible_pair` '
                    'for more details.'.format(type_spec))
def _generic_op_can_be_applied(x, y):
  """Returns `True` iff a generic operator can accept `x` and `y`.

  The pair is acceptable when the two type signatures are upcast-compatible,
  or when the first argument is federated (deeper checks happen elsewhere).
  """
  # Preserve the original short-circuit order: upcast compatibility is
  # consulted first, the federated check only as a fallback.
  if type_utils.is_binary_op_with_upcast_compatible_pair(
      x.type_signature, y.type_signature):
    return True
  return isinstance(x.type_signature, computation_types.FederatedType)
def test_fails_named_tuple_type_and_non_scalar_tensor(self):
  # A non-scalar tensor cannot be upcast against a named tuple of tensors.
  tuple_type = [('a', computation_types.TensorType(tf.int32, [2, 2]))]
  non_scalar = computation_types.TensorType(tf.int32, [2])
  result = type_utils.is_binary_op_with_upcast_compatible_pair(
      tuple_type, non_scalar)
  self.assertFalse(result)
def test_passes_named_tuple_and_compatible_scalar(self):
  # A scalar with the same dtype as the tuple's leaves is upcast-compatible.
  tuple_type = [('a', computation_types.TensorType(tf.int32, [2, 2]))]
  scalar_type = tf.int32
  result = type_utils.is_binary_op_with_upcast_compatible_pair(
      tuple_type, scalar_type)
  self.assertTrue(result)
def test_fails_scalars_different_dtypes(self):
  # Scalars of differing dtypes are not upcast-compatible.
  result = type_utils.is_binary_op_with_upcast_compatible_pair(
      tf.int32, tf.float32)
  self.assertFalse(result)
def test_passes_empty_tuples(self):
  # Two empty tuples are trivially compatible.
  compatible = type_utils.is_binary_op_with_upcast_compatible_pair([], [])
  self.assertTrue(compatible)
def test_passes_on_none(self):
  # `None` on both sides is accepted by the compatibility check.
  compatible = type_utils.is_binary_op_with_upcast_compatible_pair(None, None)
  self.assertTrue(compatible)