def sequence_sum(self, value):
  """Implements `sequence_sum` as defined in `api/intrinsics.py`.

  Accepts either a bare sequence or a federated sequence. A bare sequence is
  lowered directly via `create_sequence_sum`. A federated sequence gets a
  `SEQUENCE_SUM` intrinsic applied per placement:
    * `federated_apply` at `placements.SERVER`;
    * `federated_map` at `placements.CLIENTS`.

  Args:
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  value = value_impl.to_value(value, None, self._context_stack)
  if isinstance(value.type_signature, computation_types.SequenceType):
    sequence_type = value.type_signature
  else:
    # Anything that is not a bare sequence must be a federated sequence;
    # `check_type` raises TypeError otherwise.
    py_typecheck.check_type(value.type_signature,
                            computation_types.FederatedType)
    py_typecheck.check_type(value.type_signature.member,
                            computation_types.SequenceType)
    sequence_type = value.type_signature.member
  element_type = sequence_type.element

  if not type_utils.is_sum_compatible(element_type):
    # Report the sequence type itself. A bare `SequenceType` has no
    # `.member`, so formatting `value.type_signature.member` unconditionally
    # would raise AttributeError instead of the intended TypeError.
    raise TypeError(
        'The value type {} is not compatible with the sum operator.'.format(
            sequence_type))

  if isinstance(value.type_signature, computation_types.SequenceType):
    value = value_impl.ValueImpl.get_comp(value)
    return computation_constructing_utils.create_sequence_sum(value)

  # The checks above guarantee `value` is a federated sequence here, so the
  # intrinsic maps the member sequence type to its element type.
  intrinsic_type = computation_types.FunctionType(sequence_type, element_type)
  intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_SUM.uri, intrinsic_type)
  intrinsic_impl = value_impl.ValueImpl(intrinsic, self._context_stack)
  placement = value.type_signature.placement
  if placement is placements.SERVER:
    return self.federated_apply(intrinsic_impl, value)
  elif placement is placements.CLIENTS:
    return self.federated_map(intrinsic_impl, value)
  else:
    raise TypeError('Unsupported placement {}.'.format(placement))
def test_returns_federated_sum(self):
  """Checks the block built by `create_sequence_sum` for an int32 sequence.

  NOTE(review): despite its name, this exercises a bare (non-federated)
  sequence; the name is kept unchanged for test-ID stability.
  """
  sequence = computation_building_blocks.Data(
      'v', computation_types.SequenceType(tf.int32))
  result = computation_constructing_utils.create_sequence_sum(sequence)
  self.assertEqual(result.tff_repr, 'sequence_sum(v)')
  self.assertEqual(str(result.type_signature), 'int32')
def test_raises_type_error_with_none_value(self):
  """`create_sequence_sum` must reject `None` with a TypeError."""
  self.assertRaises(TypeError,
                    computation_constructing_utils.create_sequence_sum, None)