def sequence_sum(self, value):
        """Implements `sequence_sum` as defined in `api/intrinsics.py`.

        Accepts either a sequence or a federated value whose member is a
        sequence. An unplaced sequence is summed directly via
        `create_sequence_sum`; a federated sequence is summed by applying the
        `SEQUENCE_SUM` intrinsic through `federated_apply` (SERVER placement)
        or `federated_map` (CLIENTS placement).

        Args:
          value: As in `api/intrinsics.py`. Converted with
            `value_impl.to_value` before any type inspection.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: As in `api/intrinsics.py`; in particular when the
            sequence's element type is not sum-compatible, or when a federated
            value has a placement other than SERVER or CLIENTS.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        value_type = value.type_signature
        is_sequence = isinstance(value_type, computation_types.SequenceType)

        # Locate the sequence type being summed: either the value's own type,
        # or the member type of a federated value.
        if is_sequence:
            sequence_type = value_type
        else:
            py_typecheck.check_type(value_type,
                                    computation_types.FederatedType)
            py_typecheck.check_type(value_type.member,
                                    computation_types.SequenceType)
            sequence_type = value_type.member

        if not type_utils.is_sum_compatible(sequence_type.element):
            # Report the sequence type that was actually inspected. Formatting
            # `value_type.member` here would fail for an unplaced sequence,
            # since only the federated path above reads a `member` attribute.
            raise TypeError(
                'The value type {} is not compatible with the sum operator.'.
                format(sequence_type))

        if is_sequence:
            # NOTE(review): this branch returns the result of
            # `create_sequence_sum` as-is, while the federated branches return
            # whatever `federated_apply`/`federated_map` produce - confirm the
            # callers expect both shapes.
            comp = value_impl.ValueImpl.get_comp(value)
            return computation_constructing_utils.create_sequence_sum(comp)

        if not isinstance(value_type, computation_types.FederatedType):
            # Defensive: presumably unreachable because `check_type` above is
            # expected to reject non-federated types - kept to preserve the
            # original error for that case.
            raise TypeError(
                'Cannot apply `tff.sequence_sum()` to a value of type {}.'.
                format(value_type))

        # Build the intrinsic `sequence<T> -> T` to apply to the member
        # constituents of the federated value.
        intrinsic_type = computation_types.FunctionType(
            sequence_type, sequence_type.element)
        intrinsic = computation_building_blocks.Intrinsic(
            intrinsic_defs.SEQUENCE_SUM.uri, intrinsic_type)
        intrinsic_impl = value_impl.ValueImpl(intrinsic, self._context_stack)

        placement = value_type.placement
        if placement is placements.SERVER:
            return self.federated_apply(intrinsic_impl, value)
        if placement is placements.CLIENTS:
            return self.federated_map(intrinsic_impl, value)
        raise TypeError('Unsupported placement {}.'.format(placement))
示例#2
0
 def test_returns_federated_sum(self):
     """Checks the representation and type of a `create_sequence_sum` result.

     Builds a `Data` block named 'v' of type sequence-of-int32, passes it to
     `create_sequence_sum`, and verifies the compact representation and the
     string form of the resulting type signature.
     """
     sequence_type = computation_types.SequenceType(tf.int32)
     data = computation_building_blocks.Data('v', sequence_type)
     result = computation_constructing_utils.create_sequence_sum(data)
     expected_repr = 'sequence_sum(v)'
     expected_type = 'int32'
     self.assertEqual(result.tff_repr, expected_repr)
     self.assertEqual(str(result.type_signature), expected_type)
示例#3
0
 def test_raises_type_error_with_none_value(self):
     """Checks that `create_sequence_sum` raises `TypeError` for `None`."""
     self.assertRaises(
         TypeError, computation_constructing_utils.create_sequence_sum, None)