def _check_top_level_compatibility_with_generic_operators(x, y, op_name):
    """Validates, without recursing, that `x` and `y` can share a generic op.

    The check has two stages. First, both type signatures must consist solely
    of named-tuple, tensor and federated types. Second, the outermost types of
    the two arguments must line up with each other.

    Args:
      x: First operand; only its `type_signature` attribute is read.
      y: Second operand; only its `type_signature` attribute is read.
      op_name: Name of the operator, interpolated into error messages.

    Returns:
      `None` in every non-raising case.

    Raises:
      TypeError: If either signature contains a type outside the whitelist;
        if `x` is federated and `y` is not federated, is placed differently,
        or has a member that is not an upcast-compatible pair with `x`'s
        member; or if `x` is a named tuple that is not upcast-compatible with
        `y` while `y` is not a named tuple or `dir()` of the two signatures
        differs.
    """
    permitted_types = (
        computation_types.NamedTupleType,
        computation_types.TensorType,
        computation_types.FederatedType,
    )
    x_type = x.type_signature
    x_ok = type_analysis.contains_only_types(x_type, permitted_types)
    y_type = y.type_signature
    y_ok = type_analysis.contains_only_types(y_type, permitted_types)

    def _single_offender_message(position, offending_type):
        template = (
            '{} is only implemented for pairs of '
            'arguments both containing only federated, tuple and '
            'tensor types; you have passed argument at index {} of type {} ')
        return template.format(op_name, position, offending_type)

    if not x_ok or not y_ok:
        if y_ok:
            # Only the first argument is outside the whitelist.
            message = _single_offender_message(0, x_type)
        elif x_ok:
            # Only the second argument is outside the whitelist.
            message = _single_offender_message(1, y_type)
        else:
            both_offenders_template = (
                '{} is only implemented for pairs of '
                'arguments both containing only federated, tuple and '
                'tensor types; both your arguments fail this condition. \n'
                'You have passed first argument of type {} '
                'and second argument of type {}.')
            message = both_offenders_template.format(op_name, x_type, y_type)
        raise TypeError(message)

    mismatch_template = (
        '{} does not accept arguments of type {} and '
        '{}, as they are mismatched at the top level.')
    top_level_mismatch = mismatch_template.format(op_name, x_type, y_type)

    if isinstance(x_type, computation_types.FederatedType):
        # A federated `x` needs a federated `y` with the same placement and an
        # upcast-compatible member type; tests run in this order.
        if not isinstance(y_type, computation_types.FederatedType):
            raise TypeError(top_level_mismatch)
        if x_type.placement != y_type.placement:
            raise TypeError(top_level_mismatch)
        if not type_analysis.is_binary_op_with_upcast_compatible_pair(
                x_type.member, y_type.member):
            raise TypeError(top_level_mismatch)

    if isinstance(x_type, computation_types.NamedTupleType):
        if type_analysis.is_binary_op_with_upcast_compatible_pair(
                x_type, y_type):
            return None
        if not isinstance(y_type, computation_types.NamedTupleType):
            raise TypeError(top_level_mismatch)
        # NOTE(review): presumably `dir()` on a named tuple type lists its
        # element names -- confirm against the NamedTupleType implementation.
        if dir(x_type) != dir(y_type):
            raise TypeError(top_level_mismatch)

    return None
def is_generic_op_compatible_type(type_spec):
    """Reports whether `type_spec` is acceptable to the generic operators.

    Args:
      type_spec: The type to inspect, or `None`.

    Returns:
      `True` when `type_spec` is `None`; otherwise the result of
      `type_analysis.contains_only_types` evaluated against a whitelist of
      named-tuple and tensor types.
    """
    # `None` short-circuits before the whitelist is consulted.
    return type_spec is None or type_analysis.contains_only_types(
        type_spec,
        (computation_types.NamedTupleType, computation_types.TensorType),
    )
def is_tensorflow_compatible_type(type_spec):
    """Reports whether `type_spec` is acceptable to `tf_computation`.

    Args:
      type_spec: The type to inspect, or `None`.

    Returns:
      `True` when `type_spec` is `None`; otherwise the result of
      `type_analysis.contains_only_types` evaluated against a whitelist of
      named-tuple, sequence and tensor types.
    """
    if type_spec is not None:
        permitted_types = (
            computation_types.NamedTupleType,
            computation_types.SequenceType,
            computation_types.TensorType,
        )
        return type_analysis.contains_only_types(type_spec, permitted_types)
    return True
def test_returns_false(self, type_signature, types):
    # `contains_only_types` must report False for the supplied signature and
    # whitelist pair.
    self.assertFalse(
        type_analysis.contains_only_types(type_signature, types))
def _only_tuple_or_tensor(value):
    """Checks `value.type_signature` against named-tuple and tensor types.

    Returns whatever `type_analysis.contains_only_types` reports for that
    signature and a whitelist of `NamedTupleType` and `TensorType`.
    """
    permitted_types = (
        computation_types.NamedTupleType,
        computation_types.TensorType,
    )
    signature = value.type_signature
    return type_analysis.contains_only_types(signature, permitted_types)