Beispiel #1
0
    def _check_top_level_compatibility_with_generic_operators(x, y, op_name):
        """Checks, without recursing, that `x` and `y` may be paired by `op_name`.

        Two stages are applied to the operands' `type_signature`s:

        1. Each signature must satisfy `type_analysis.contains_only_types` for
           `NamedTupleType`, `TensorType` and `FederatedType`.
        2. Only the outermost level of the two signatures is then compared:
           * federated `x`: `y` must be federated, have the same placement, and
             the two members must pass
             `type_analysis.is_binary_op_with_upcast_compatible_pair`;
           * named-tuple `x`: the pair is accepted outright if that predicate
             accepts the two full signatures; otherwise `y` must be a named
             tuple whose `dir()` listing equals that of `x`.
           No other top-level combination is rejected here.

        Args:
          x: First operand; only its `type_signature` attribute is read.
          y: Second operand; only its `type_signature` attribute is read.
          op_name: Operator name, interpolated into error messages.

        Returns:
          None.

        Raises:
          TypeError: If either stage fails.
        """
        x_type = x.type_signature
        y_type = y.type_signature
        allowed = (computation_types.NamedTupleType,
                   computation_types.TensorType,
                   computation_types.FederatedType)
        x_ok = type_analysis.contains_only_types(x_type, allowed)
        y_ok = type_analysis.contains_only_types(y_type, allowed)

        if not x_ok and not y_ok:
            raise TypeError(
                '{} is only implemented for pairs of '
                'arguments both containing only federated, tuple and '
                'tensor types; both your arguments fail this condition. '
                'You have passed first argument of type {} '
                'and second argument of type {}.'.format(
                    op_name, x_type, y_type))
        if not (x_ok and y_ok):
            # Exactly one operand failed; report that one by its position.
            bad_index, bad_type = (0, x_type) if y_ok else (1, y_type)
            raise TypeError(
                '{} is only implemented for pairs of '
                'arguments both containing only federated, tuple and '
                'tensor types; you have passed argument at index {} of type {} '
                .format(op_name, bad_index, bad_type))

        mismatch_message = (
            '{} does not accept arguments of type {} and '
            '{}, as they are mismatched at the top level.'.format(
                op_name, x_type, y_type))

        if isinstance(x_type, computation_types.FederatedType):
            if not isinstance(y_type, computation_types.FederatedType):
                raise TypeError(mismatch_message)
            if x_type.placement != y_type.placement:
                raise TypeError(mismatch_message)
            if not type_analysis.is_binary_op_with_upcast_compatible_pair(
                    x_type.member, y_type.member):
                raise TypeError(mismatch_message)

        if isinstance(x_type, computation_types.NamedTupleType):
            if type_analysis.is_binary_op_with_upcast_compatible_pair(
                    x_type, y_type):
                return None
            if not isinstance(y_type, computation_types.NamedTupleType):
                raise TypeError(mismatch_message)
            # NOTE(review): `dir()` equality presumably compares tuple element
            # names (depends on how `NamedTupleType` implements `__dir__`) --
            # confirm.
            if dir(x_type) != dir(y_type):
                raise TypeError(mismatch_message)
Beispiel #2
0
def is_generic_op_compatible_type(type_spec):
    """Returns whether `type_spec` may be used with generic operators.

    `None` is accepted unconditionally. Any other value is accepted only when
    `type_analysis.contains_only_types` reports that it contains nothing but
    `NamedTupleType` and `TensorType`.
    """
    if type_spec is not None:
        permitted = (computation_types.NamedTupleType,
                     computation_types.TensorType)
        return type_analysis.contains_only_types(type_spec, permitted)
    return True
Beispiel #3
0
def is_tensorflow_compatible_type(type_spec):
    """Returns whether `type_spec` is acceptable to `tf_computation`.

    `None` always qualifies. Otherwise `type_spec` must satisfy
    `type_analysis.contains_only_types` against the allowlist of
    `NamedTupleType`, `SequenceType` and `TensorType`.
    """
    # The allowlist is only built when needed, i.e. for a non-None spec.
    return type_spec is None or type_analysis.contains_only_types(
        type_spec,
        (computation_types.NamedTupleType,
         computation_types.SequenceType,
         computation_types.TensorType))
Beispiel #4
0
 def test_returns_false(self, type_signature, types):
     """Asserts `contains_only_types` gives a falsy result for these inputs."""
     self.assertFalse(
         type_analysis.contains_only_types(type_signature, types))
Beispiel #5
0
 def _only_tuple_or_tensor(value):
     """Returns whether `value.type_signature` holds only tuple and tensor types.

     The check is delegated to `type_analysis.contains_only_types`, with
     `NamedTupleType` and `TensorType` as the permitted types.
     """
     permitted = (computation_types.NamedTupleType, computation_types.TensorType)
     signature = value.type_signature
     return type_analysis.contains_only_types(signature, permitted)