Ejemplo n.º 1
0
 def test_long_formatted_with_diff(self):
     """Compares the mismatch message for 20- vs 21-element structs to a golden.

     Both structs consist solely of unnamed `int32` tensor elements; the
     message is requested for the `EQUIVALENT` relation.
     """
     element = computation_types.TensorType(tf.int32)
     shorter, longer = (
         computation_types.StructType([(None, element)] * width)
         for width in (20, 21))
     message = computation_types.type_mismatch_error_message(
         shorter, longer, computation_types.TypeRelation.EQUIVALENT)
     golden.check_string('long_formatted_with_diff.expected', message)
 def wrapped_func(*args, **kwargs):
     """Calls `fn` and verifies the type of what it returns.

     `fn` and `expected_return_type` come from the enclosing scope.

     Args:
       *args: Positional arguments forwarded to `fn`.
       **kwargs: Keyword arguments forwarded to `fn`.

     Returns:
       The value returned by `fn`, unchanged.

     Raises:
       ValueError: If `fn` returns `None`.
       TypeError: If the type inferred from the returned value is not identical
         to `expected_return_type`.
     """
     returned = fn(*args, **kwargs)
     if returned is None:
         raise ValueError('TFF computations may not return `None`. '
                          'Consider instead returning `()`.')
     inferred = type_conversions.infer_type(returned)
     if inferred.is_identical_to(expected_return_type):
         return returned
     # Build the header first so evaluation order matches the message layout.
     header = (
         f'Value returned from `{fn.__name__}` did not match asserted type.\n')
     details = computation_types.type_mismatch_error_message(
         inferred,
         expected_return_type,
         computation_types.TypeRelation.IDENTICAL,
         second_is_expected=True)
     raise TypeError(header + details)
Ejemplo n.º 3
0
def check_type(value: Any, type_spec: computation_types.Type):
  """Verifies that the TFF type inferred from `value` fits `type_spec`.

  Args:
    value: The object whose TFF type is inferred and checked.
    type_spec: A `computation_types.Type` that the inferred type of `value`
      must be assignable to.

  Raises:
    TypeError: If the inferred type of `value` is not assignable to `type_spec`.
  """
  # Validate the spec itself before inferring anything from `value`.
  py_typecheck.check_type(type_spec, computation_types.Type)
  inferred = type_conversions.infer_type(value)
  if type_spec.is_assignable_from(inferred):
    return
  mismatch = computation_types.type_mismatch_error_message(
      inferred,
      type_spec,
      computation_types.TypeRelation.ASSIGNABLE,
      second_is_expected=True)
  raise TypeError(mismatch)
Ejemplo n.º 4
0
def _select_parameter_mismatch(
    param_type,
    type_desc,
    name,
    secure,
    expected_type=None,
):
    """Throws a `TypeError` indicating a mismatched `select` parameter type."""
    secure_string = '_secure' if secure else ''
    intrinsic_name = f'federated{secure_string}_select'
    message = (
        f'Expected `{intrinsic_name}` parameter `{name}` to be {type_desc}')
    if expected_type is None:
        raise TypeError(f'{message}, found value of type {param_type}')
    else:
        raise TypeError(f'{message}:\n' +
                        computation_types.type_mismatch_error_message(
                            param_type,
                            expected_type,
                            computation_types.TypeRelation.ASSIGNABLE,
                            second_is_expected=True))
Ejemplo n.º 5
0
 def test_container_types_full_repr(self):
     """Compares the mismatch message for list- vs tuple-backed structs to a golden.

     Both structs are empty and differ only in their Python container type;
     the message is requested for the `EQUIVALENT` relation.
     """
     list_backed, tuple_backed = (
         computation_types.StructWithPythonType([], container)
         for container in (list, tuple))
     message = computation_types.type_mismatch_error_message(
         list_backed, tuple_backed, computation_types.TypeRelation.EQUIVALENT)
     golden.check_string('container_types_full_repr.expected', message)
Ejemplo n.º 6
0
 def test_short_compact_repr(self):
     """Compares the mismatch message for `int32` vs `bool` tensors to a golden.

     The message is requested for the `EQUIVALENT` relation.
     """
     int_tensor = computation_types.TensorType(tf.int32)
     bool_tensor = computation_types.TensorType(tf.bool)
     relation = computation_types.TypeRelation.EQUIVALENT
     message = computation_types.type_mismatch_error_message(
         int_tensor, bool_tensor, relation)
     golden.check_string('short_compact_repr.expected', message)
Ejemplo n.º 7
0
    def __init__(self, *, up_to_merge: computation_base.Computation,
                 merge: computation_base.Computation,
                 after_merge: computation_base.Computation):
        """Checks that the three computations fit together, then stores them.

        Args:
          up_to_merge: Computation whose result must be a federated type
            placed at the server.
          merge: Computation whose parameter must be assignable from an
            unnamed two-element struct of the member type of `up_to_merge`'s
            result, and whose result must be assignable to each of its two
            parameter elements.
          after_merge: Computation whose parameter must be assignable from
            the server-placed result of `merge`, paired with the parameter of
            `up_to_merge` when that parameter is not `None`.

        Raises:
          UpToMergeTypeError: If the result of `up_to_merge` is not a
            server-placed federated type.
          MergeTypeNotAssignableError: If either check on the signature of
            `merge` fails.
          AfterMergeStructureError: If `after_merge` exposes
            `to_building_block` and its building block contains a
            function-typed intrinsic whose parameter contains a
            clients-placed type and whose result contains a server-placed
            type.
        """
        produced = up_to_merge.type_signature.result
        if not produced.is_federated() or not produced.placement.is_server():
            raise UpToMergeTypeError(
                'Expected `up_to_merge` to return a single `tff.SERVER`-placed '
                f'value; found return type {produced}.')

        # TFF's StructType assignability relation ensures that an unnamed struct can
        # be assigned to any struct with names.
        expected_merge_param_type = computation_types.StructType(
            [(None, produced.member), (None, produced.member)])
        merge_param = merge.type_signature.parameter
        if not merge_param.is_assignable_from(expected_merge_param_type):
            header = 'Type mismatch checking `merge` type signature.\n'
            details = computation_types.type_mismatch_error_message(
                merge_param,
                expected_merge_param_type,
                computation_types.TypeRelation.ASSIGNABLE,
                second_is_expected=True)
            raise MergeTypeNotAssignableError(header + details)

        merge_result = merge.type_signature.result
        # `all` stops at the first failing element, like the chained `and`.
        result_fits_both = all(
            merge_param[slot].is_assignable_from(merge_result)
            for slot in (0, 1))
        if not result_fits_both:
            raise MergeTypeNotAssignableError(
                'Expected `merge` to have result which is assignable to '
                'each element of its parameter tuple; found parameter '
                f'of type: \n{merge_param}\nAnd result of type: \n'
                f'{merge_result}')

        merged_at_server = computation_types.at_server(merge_result)
        original_arg_type = up_to_merge.type_signature.parameter
        if original_arg_type is None:
            # TODO(b/147499373): If None arguments were uniformly represented as empty
            # tuples, we could avoid this and related ugly if/else casing.
            expected_after_merge_arg_type = merged_at_server
        else:
            expected_after_merge_arg_type = computation_types.StructType([
                (None, original_arg_type),
                (None, merged_at_server),
            ])

        # NOTE(review): presumably raises when the types are not assignable;
        # confirm against `check_assignable_from`.
        after_merge.type_signature.parameter.check_assignable_from(
            expected_after_merge_arg_type)

        def _placed_at(placement: placements.PlacementLiteral):
            """Builds a predicate matching federated types at `placement`."""

            def _matches(type_signature: computation_types.Type) -> bool:
                return (type_signature.is_federated()
                        and type_signature.placement == placement)

            return _matches

        # Filled in as a side effect of the predicate below so the error
        # message can list every offending intrinsic that was visited.
        aggregations = set()

        def _aggregation_predicate(
                comp: building_blocks.ComputationBuildingBlock) -> bool:
            """Records and reports intrinsics moving clients values to server."""
            if not comp.is_intrinsic():
                return False
            if not comp.type_signature.is_function():
                return False
            takes_clients_value = type_analysis.contains(
                comp.type_signature.parameter,
                _placed_at(placements.CLIENTS))
            yields_server_value = type_analysis.contains(
                comp.type_signature.result,
                _placed_at(placements.SERVER))
            if takes_clients_value and yields_server_value:
                aggregations.add((comp.uri, comp.type_signature))
                return True
            return False

        # We only know how to statically analyze computations which are backed by
        # computation.protos; to avoid opening up a visibility hole that isn't
        # technically necessary here, we prefer to simply skip the static check here
        # for computations which cannot convert themselves to building blocks.
        can_inspect = hasattr(after_merge, 'to_building_block')
        if can_inspect and tree_analysis.contains(
                after_merge.to_building_block(), _aggregation_predicate):
            formatted_aggregations = ', '.join(
                f'{uri}: {signature}' for uri, signature in aggregations)
            raise AfterMergeStructureError(
                'Expected `after_merge` to contain no intrinsics '
                'with signatures accepting values at clients and '
                'returning values at server. Found the following '
                f'aggregations: {formatted_aggregations}')

        self.up_to_merge = up_to_merge
        self.merge = merge
        self.after_merge = after_merge
    def __init__(self, *, up_to_merge: computation_base.Computation,
                 merge: computation_base.Computation,
                 after_merge: computation_base.Computation):
        """Validates `up_to_merge`, `merge` and `after_merge`, then stores them.

        Args:
          up_to_merge: Computation whose result must be a federated type
            placed at the server. Its parameter may be `None`.
          merge: Computation whose parameter must be assignable from an
            unnamed two-element struct of the member type of `up_to_merge`'s
            result, and whose result must be assignable to each of its two
            parameter elements.
          after_merge: Computation whose parameter must be assignable from
            the server-placed result of `merge`, paired with the parameter of
            `up_to_merge` when that parameter is not `None`.

        Raises:
          UpToMergeTypeError: If the result of `up_to_merge` is not a
            server-placed federated type.
          MergeTypeNotAssignableError: If either check on the signature of
            `merge` fails.
          AfterMergeStructureError: If `after_merge` exposes
            `to_building_block` and its building block contains a
            function-typed intrinsic whose parameter contains a
            clients-placed type and whose result contains a server-placed
            type.
        """
        if not (up_to_merge.type_signature.result.is_federated()
                and up_to_merge.type_signature.result.placement.is_server()):
            raise UpToMergeTypeError(
                'Expected `up_to_merge` to return a single `tff.SERVER`-placed '
                f'value; found return type {up_to_merge.type_signature.result}.'
            )

        # TFF's StructType assignability relation ensures that an unnamed struct can
        # be assigned to any struct with names.
        expected_merge_param_type = computation_types.StructType([
            (None, up_to_merge.type_signature.result.member),
            (None, up_to_merge.type_signature.result.member)
        ])
        if not merge.type_signature.parameter.is_assignable_from(
                expected_merge_param_type):

            raise MergeTypeNotAssignableError(
                'Type mismatch checking `merge` type signature.\n' +
                computation_types.type_mismatch_error_message(
                    merge.type_signature.parameter,
                    expected_merge_param_type,
                    computation_types.TypeRelation.ASSIGNABLE,
                    second_is_expected=True))
        if not (merge.type_signature.parameter[0].is_assignable_from(
                merge.type_signature.result)
                and merge.type_signature.parameter[1].is_assignable_from(
                    merge.type_signature.result)):
            raise MergeTypeNotAssignableError(
                'Expected `merge` to have result which is assignable to '
                'each element of its parameter tuple; found parameter '
                f'of type: \n{merge.type_signature.parameter}\nAnd result of type: \n'
                f'{merge.type_signature.result}')

        if up_to_merge.type_signature.parameter is not None:
            expected_after_merge_arg_type = computation_types.StructType([
                (None, up_to_merge.type_signature.parameter),
                (None,
                 computation_types.at_server(merge.type_signature.result)),
            ])
        else:
            # A no-argument `up_to_merge` has no parameter type to pair with
            # the merged result, so `after_merge` is expected to accept the
            # server-placed merge result alone rather than a struct holding
            # a `None` element type.
            expected_after_merge_arg_type = computation_types.at_server(
                merge.type_signature.result)

        # NOTE(review): presumably raises when the types are not assignable;
        # confirm against `check_assignable_from`.
        after_merge.type_signature.parameter.check_assignable_from(
            expected_after_merge_arg_type)

        def _federated_type_predicate(
                type_signature: computation_types.Type,
                placement: placements.PlacementLiteral) -> bool:
            """Returns whether `type_signature` is federated at `placement`."""
            return (type_signature.is_federated()
                    and type_signature.placement == placement)

        def _moves_clients_to_server_predicate(
                intrinsic: building_blocks.Intrinsic):
            """Returns whether `intrinsic` takes clients and yields server values."""
            parameter_contains_clients_placement = type_analysis.contains(
                intrinsic.type_signature.parameter,
                lambda x: _federated_type_predicate(x, placements.CLIENTS))
            result_contains_server_placement = type_analysis.contains(
                intrinsic.type_signature.result,
                lambda x: _federated_type_predicate(x, placements.SERVER))
            return (parameter_contains_clients_placement
                    and result_contains_server_placement)

        # Filled in as a side effect of `_aggregation_predicate` so the error
        # message can list every offending intrinsic that was visited.
        aggregations = set()

        def _aggregation_predicate(
                comp: building_blocks.ComputationBuildingBlock) -> bool:
            """Records and reports intrinsics moving clients values to server."""
            if not comp.is_intrinsic():
                return False
            if not comp.type_signature.is_function():
                return False
            if _moves_clients_to_server_predicate(comp):
                aggregations.add((comp.uri, comp.type_signature))
                return True
            return False

        # Only computations that can convert themselves to building blocks can
        # be analyzed statically; for any other computation the structural
        # check is skipped instead of failing with an `AttributeError`.
        if hasattr(
                after_merge, 'to_building_block') and tree_analysis.contains(
                    after_merge.to_building_block(), _aggregation_predicate):
            formatted_aggregations = ', '.join(
                '{}: {}'.format(elem[0], elem[1]) for elem in aggregations)
            raise AfterMergeStructureError(
                'Expected `after_merge` to contain no intrinsics '
                'with signatures accepting values at clients and '
                'returning values at server. Found the following '
                f'aggregations: {formatted_aggregations}')

        self.up_to_merge = up_to_merge
        self.merge = merge
        self.after_merge = after_merge