예제 #1
0
    def _check_top_level_compatibility_with_generic_operators(x, y, op_name):
        """Performs a non-recursive compatibility check on `x` and `y`.

        Args:
            x: First operand; must expose a `type_signature` attribute.
            y: Second operand; must expose a `type_signature` attribute.
            op_name: Operator name interpolated into the error messages.

        Returns:
            None when no incompatibility is detected.

        Raises:
            TypeError: If either argument contains types other than federated,
                tuple and tensor types, or if the two type signatures are
                mismatched at the top level.
        """
        supported_kinds = (computation_types.NamedTupleType,
                           computation_types.TensorType,
                           computation_types.FederatedType)
        x_type = x.type_signature
        x_ok = type_analysis.contains_only_types(x_type, supported_kinds)
        y_type = y.type_signature
        y_ok = type_analysis.contains_only_types(y_type, supported_kinds)

        def _one_bad_argument(position, offending_type):
            # Message used when exactly one of the two arguments is rejected.
            template = (
                '{} is only implemented for pairs of '
                'arguments both containing only federated, tuple and '
                'tensor types; you have passed argument at index {} '
                'of type {} ')
            return template.format(op_name, position, offending_type)

        # Stage 1: both arguments must be built only from supported types.
        if not x_ok and not y_ok:
            both_bad_template = (
                '{} is only implemented for pairs of '
                'arguments both containing only federated, tuple and '
                'tensor types; both your arguments fail this condition. '
                'You have passed first argument of type {} '
                'and second argument of type {}.')
            raise TypeError(both_bad_template.format(op_name, x_type, y_type))
        if not x_ok:
            raise TypeError(_one_bad_argument(0, x_type))
        if not y_ok:
            raise TypeError(_one_bad_argument(1, y_type))

        # Stage 2: the outermost types of the two arguments must line up.
        mismatch_message = (
            '{} does not accept arguments of type {} and '
            '{}, as they are mismatched at the top level.'.format(
                op_name, x_type, y_type))

        if isinstance(x_type, computation_types.FederatedType):
            # A federated `x` needs a federated `y` with the same placement
            # and an upcast-compatible member type.
            if not isinstance(y_type, computation_types.FederatedType):
                raise TypeError(mismatch_message)
            if x_type.placement != y_type.placement:
                raise TypeError(mismatch_message)
            if not type_analysis.is_binary_op_with_upcast_compatible_pair(
                    x_type.member, y_type.member):
                raise TypeError(mismatch_message)

        if not isinstance(x_type, computation_types.NamedTupleType):
            return None

        # An upcast-compatible pair is accepted without comparing names.
        if type_analysis.is_binary_op_with_upcast_compatible_pair(
                x_type, y_type):
            return None
        if not isinstance(y_type, computation_types.NamedTupleType):
            raise TypeError(mismatch_message)
        if dir(x_type) != dir(y_type):
            raise TypeError(mismatch_message)
        return None
예제 #2
0
 def test_fails_compatible_scalar_and_named_tuple(self):
     # A float32 tensor type with no explicit shape, followed by a struct
     # with a single int32 [2, 2] field, must be rejected.
     tensor_type = computation_types.TensorType(tf.float32)
     struct_type = computation_types.StructType(
         [('a', computation_types.TensorType(tf.int32, [2, 2]))])
     is_compatible = type_analysis.is_binary_op_with_upcast_compatible_pair(
         tensor_type, struct_type)
     self.assertFalse(is_compatible)
예제 #3
0
    def _pack_binary_operator_args(x, y):
        """Packs arguments to binary operator into a single arg.

        Args:
            x: First operand; must expose a `type_signature` attribute.
            y: Second operand; must expose a `type_signature` attribute.

        Returns:
            A single value holding both operands: a two-element tuple value
            when both operands contain only tuple and tensor types, or the
            result of `intrinsics.federated_zip` applied to that tuple when
            both operands are federated with the same placement.

        Raises:
            TypeError: If the operands are neither both tuple/tensor-only nor
                both federated with the same placement, or if the members of
                the federated types are not an upcast-compatible pair.
        """
        def _only_tuple_or_tensor(value):
            return type_analysis.contains_only_types(
                value.type_signature, (computation_types.NamedTupleType,
                                       computation_types.TensorType))

        def _pair_of_operands():
            # Both supported cases start from the same two-element tuple.
            # `context_stack` is not defined in this block; it is taken from
            # the enclosing scope.
            return value_impl.ValueImpl(
                building_blocks.Tuple([
                    value_impl.ValueImpl.get_comp(x),
                    value_impl.ValueImpl.get_comp(y)
                ]), context_stack)

        if _only_tuple_or_tensor(x) and _only_tuple_or_tensor(y):
            return _pair_of_operands()

        x_type = x.type_signature
        y_type = y.type_signature
        both_federated_same_placement = (
            isinstance(x_type, computation_types.FederatedType)
            and isinstance(y_type, computation_types.FederatedType)
            and x_type.placement == y_type.placement)
        if not both_federated_same_placement:
            # Previously a bare `raise TypeError`; name the offending types so
            # the failure can be diagnosed from the message alone.
            raise TypeError(
                'Cannot pack arguments of types {} and {} for a binary '
                'operator: expected either two arguments containing only '
                'tuple and tensor types, or two federated arguments with the '
                'same placement.'.format(x_type, y_type))

        if not type_analysis.is_binary_op_with_upcast_compatible_pair(
                x_type.member, y_type.member):
            raise TypeError(
                'The members of the federated types {} and {} are not division '
                'compatible; see `type_analysis.is_binary_op_with_upcast_compatible_pair` '
                'for more details.'.format(x_type, y_type))
        return intrinsics.federated_zip(_pair_of_operands())
예제 #4
0
 def test_fails_named_tuple_type_and_non_scalar_tensor(self):
     # A one-field struct spec (int32 [2, 2]) paired with an int32 tensor of
     # shape [2] must be rejected.
     struct_spec = [('a', computation_types.TensorType(tf.int32, [2, 2]))]
     vector_type = computation_types.TensorType(tf.int32, [2])
     is_compatible = type_analysis.is_binary_op_with_upcast_compatible_pair(
         struct_spec, vector_type)
     self.assertFalse(is_compatible)
예제 #5
0
 def test_passes_named_tuple_and_compatible_scalar(self):
     # A one-field struct spec (int32 [2, 2]) paired with a bare `tf.int32`
     # must be accepted.
     struct_spec = [('a', computation_types.TensorType(tf.int32, [2, 2]))]
     is_compatible = type_analysis.is_binary_op_with_upcast_compatible_pair(
         struct_spec, tf.int32)
     self.assertTrue(is_compatible)
예제 #6
0
 def test_fails_scalars_different_dtypes(self):
     # `tf.int32` paired with `tf.float32` must be rejected.
     left_dtype, right_dtype = tf.int32, tf.float32
     is_compatible = type_analysis.is_binary_op_with_upcast_compatible_pair(
         left_dtype, right_dtype)
     self.assertFalse(is_compatible)
예제 #7
0
 def test_passes_empty_tuples(self):
     # Two distinct empty lists must be accepted as a pair.
     left_spec, right_spec = [], []
     is_compatible = type_analysis.is_binary_op_with_upcast_compatible_pair(
         left_spec, right_spec)
     self.assertTrue(is_compatible)
예제 #8
0
 def test_passes_on_none(self):
     # Passing `None` for both arguments must be accepted.
     is_compatible = type_analysis.is_binary_op_with_upcast_compatible_pair(
         None, None)
     self.assertTrue(is_compatible)
예제 #9
0
 def test_passes_empty_tuples(self):
     # Two separately constructed empty `StructType`s must be accepted.
     left_struct = computation_types.StructType([])
     right_struct = computation_types.StructType([])
     is_compatible = type_analysis.is_binary_op_with_upcast_compatible_pair(
         left_struct, right_struct)
     self.assertTrue(is_compatible)
예제 #10
0
 def _generic_op_can_be_applied(x, y):
     """Reports whether a generic operator may be applied to `x` and `y`.

     Returns the result of the upcast-compatibility check on the two type
     signatures when it is truthy; otherwise returns whether the type
     signature of `x` is a `FederatedType`.
     """
     upcast_result = type_analysis.is_binary_op_with_upcast_compatible_pair(
         x.type_signature, y.type_signature)
     if upcast_result:
         return upcast_result
     return isinstance(x.type_signature, computation_types.FederatedType)
예제 #11
0
 def _generic_op_can_be_applied(x, y):
     """Reports whether a generic operator may be applied to `x` and `y`.

     Returns the result of the upcast-compatibility check on the two type
     signatures when it is truthy; otherwise returns the result of
     `x.type_signature.is_federated()`.
     """
     upcast_result = type_analysis.is_binary_op_with_upcast_compatible_pair(
         x.type_signature, y.type_signature)
     if upcast_result:
         return upcast_result
     return x.type_signature.is_federated()