def test_long_formatted_with_diff(self):
   """Compares the mismatch message for two long structs to a golden file.

   The two struct types are built from the same unnamed `int32` element and
   differ only in length (20 elements versus 21).
   """
   element = (None, computation_types.TensorType(tf.int32))
   shorter = computation_types.StructType([element] * 20)
   longer = computation_types.StructType([element] * 21)
   message = computation_types.type_mismatch_error_message(
       shorter, longer, computation_types.TypeRelation.EQUIVALENT)
   golden.check_string('long_formatted_with_diff.expected', message)
 def wrapped_func(*args, **kwargs):
   """Calls `fn` and verifies that its result has the asserted type.

   All positional and keyword arguments are forwarded to `fn` unchanged.

   Returns:
     The value returned by `fn`, when its inferred type is identical to
     `expected_return_type`.

   Raises:
     ValueError: If `fn` returns `None`.
     TypeError: If the inferred type of the result is not identical to
       `expected_return_type`.
   """
   returned = fn(*args, **kwargs)
   if returned is None:
     raise ValueError('TFF computations may not return `None`. '
                      'Consider instead returning `()`.')
   inferred_type = type_conversions.infer_type(returned)
   if inferred_type.is_identical_to(expected_return_type):
     return returned
   # Build the heading first so evaluation order matches a single-expression
   # concatenation: the function name is read before the mismatch details.
   heading = (
       f'Value returned from `{fn.__name__}` did not match asserted type.\n')
   details = computation_types.type_mismatch_error_message(
       inferred_type,
       expected_return_type,
       computation_types.TypeRelation.IDENTICAL,
       second_is_expected=True)
   raise TypeError(heading + details)
Example #3
0
def _select_parameter_mismatch(
    param_type,
    type_desc,
    name,
    secure,
    expected_type=None,
):
    """Throws a `TypeError` indicating a mismatched `select` parameter type."""
    secure_string = '_secure' if secure else ''
    intrinsic_name = f'federated{secure_string}_select'
    message = (
        f'Expected `{intrinsic_name}` parameter `{name}` to be {type_desc}')
    if expected_type is None:
        raise TypeError(f'{message}, found value of type {param_type}')
    else:
        raise TypeError(f'{message}:\n' +
                        computation_types.type_mismatch_error_message(
                            param_type,
                            expected_type,
                            computation_types.TypeRelation.ASSIGNABLE,
                            second_is_expected=True))
    def federated_reduce(self, value, zero, op):
        """Implements `federated_reduce` as defined in `api/intrinsics.py`.

        Args:
            value: The value to be reduced. It is converted with
                `value_impl.to_value` and then passed through
                `value_utils.ensure_federated_value` with
                `placement_literals.CLIENTS`.
            zero: The `zero` argument of the reduction. It is converted with
                `value_impl.to_value` and must not contain federated types.
            op: The reduction operator. It is converted with
                `value_impl.to_value`, using a struct of the types of `zero`
                and of the member of `value` as the parameter type hint.

        Returns:
            A `value_impl.ValueImpl` wrapping the constructed
            `federated_reduce` computation after it has been bound as a
            reference.

        Raises:
            TypeError: If `zero` contains a federated type, if `zero` is not
                assignable to the result type of `op`, or if the type of `op`
                is not assignable to the expected reduction operator type.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        value = value_utils.ensure_federated_value(value,
                                                   placement_literals.CLIENTS,
                                                   'value to be reduced')

        zero = value_impl.to_value(zero, None, self._context_stack)
        if type_analysis.contains_federated_types(zero.type_signature):
            raise TypeError(
                '`zero` may not contain a federated type, found type:\n' +
                str(zero.type_signature))

        op = value_impl.to_value(
            op,
            None,
            self._context_stack,
            parameter_type_hint=computation_types.StructType(
                [zero.type_signature, value.type_signature.member]))
        # Presumably raises when `op` is not of a functional type; confirm
        # against `computation_types`.
        op.type_signature.check_function()
        if not op.type_signature.result.is_assignable_from(
                zero.type_signature):
            # Concatenate into a single message. Passing the two parts as
            # separate exception arguments would render the error as a tuple
            # repr instead of readable text.
            raise TypeError(
                '`zero` must be assignable to the result type from `op`:\n' +
                computation_types.type_mismatch_error_message(
                    zero.type_signature, op.type_signature.result,
                    computation_types.TypeRelation.ASSIGNABLE))
        op_type_expected = type_factory.reduction_op(
            op.type_signature.result, value.type_signature.member)
        if not op_type_expected.is_assignable_from(op.type_signature):
            raise TypeError(
                f'Expected an operator of type {op_type_expected}, '
                f'got {op.type_signature}.')

        # Unwrap each value to its underlying computation before building
        # the intrinsic call.
        value = value_impl.ValueImpl.get_comp(value)
        zero = value_impl.ValueImpl.get_comp(zero)
        op = value_impl.ValueImpl.get_comp(op)
        comp = building_block_factory.create_federated_reduce(value, zero, op)
        comp = self._bind_comp_as_reference(comp)
        return value_impl.ValueImpl(comp, self._context_stack)
 def test_container_types_full_repr(self):
   """Compares the mismatch message for list vs. tuple structs to a golden file.

   Both struct types are empty and differ only in their Python container
   type (`list` versus `tuple`).
   """
   list_struct = computation_types.StructWithPythonType([], list)
   tuple_struct = computation_types.StructWithPythonType([], tuple)
   message = computation_types.type_mismatch_error_message(
       list_struct, tuple_struct, computation_types.TypeRelation.EQUIVALENT)
   golden.check_string('container_types_full_repr.expected', message)
 def test_short_compact_repr(self):
   """Compares the mismatch message for two tensor types to a golden file.

   The two types are an `int32` tensor and a `bool` tensor.
   """
   int_tensor = computation_types.TensorType(tf.int32)
   bool_tensor = computation_types.TensorType(tf.bool)
   message = computation_types.type_mismatch_error_message(
       int_tensor, bool_tensor, computation_types.TypeRelation.EQUIVALENT)
   golden.check_string('short_compact_repr.expected', message)