def sequence_map(fn, arg):
  """Maps a TFF sequence `arg` pointwise using a given function `fn`.

  This function supports two modes of usage:

  * When applied to a non-federated sequence, it maps individual elements of
    the sequence pointwise. If the supplied `fn` is of type `T->U` and the
    sequence `arg` is of type `T*` (a sequence of `T`-typed elements), the
    result is a sequence of type `U*` (a sequence of `U`-typed elements), with
    each element of the input sequence individually mapped by `fn`. In this
    mode of usage, `sequence_map` behaves like a computation with type
    signature `<T->U,T*> -> U*`.

  * When applied to a federated sequence, `sequence_map` behaves as if it were
    individually applied to each member constituent. In this mode of usage,
    one can think of `sequence_map` as a specialized variant of
    `federated_map` that is designed to work with sequences and allows one to
    specify a `fn` that operates at the level of individual elements. Indeed,
    under the hood, when `sequence_map` is invoked on a federated type, it
    injects `federated_map`, thus emitting expressions like
    `federated_map(x -> sequence_map(fn, x), arg)`.

  Args:
    fn: A mapping function to apply pointwise to elements of `arg`.
    arg: A value of a TFF type that is either a sequence, or a federated
      sequence.

  Returns:
    A sequence with the result of applying `fn` pointwise to each element of
    `arg`, or if `arg` was federated, a federated sequence with the result of
    invoking `sequence_map` on member sequences locally and independently at
    each location.

  Raises:
    TypeError: If the arguments are not of the appropriate types.
  """
  fn = value_impl.to_value(fn, None)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  arg = value_impl.to_value(arg, None)

  if arg.type_signature.is_sequence():
    comp = building_block_factory.create_sequence_map(fn.comp, arg.comp)
    comp = _bind_comp_as_reference(comp)
    return value_impl.Value(comp)
  elif arg.type_signature.is_federated():
    # Build a SEQUENCE_MAP intrinsic typed `<(T -> U), T*> -> U*` from the
    # signature of `fn`. Currying it and binding `fn` yields a local
    # `T* -> U*` function, which `federated_map` then applies at each
    # placement.
    parameter_type = computation_types.SequenceType(
        fn.type_signature.parameter)
    result_type = computation_types.SequenceType(fn.type_signature.result)
    intrinsic_type = computation_types.FunctionType(
        (fn.type_signature, parameter_type), result_type)
    intrinsic = building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_MAP.uri,
                                          intrinsic_type)
    intrinsic_impl = value_impl.Value(intrinsic)
    local_fn = value_utils.get_curried(intrinsic_impl)(fn)
    return federated_map(local_fn, arg)
  else:
    raise TypeError(
        'Cannot apply `tff.sequence_map()` to a value of type {}.'.format(
            arg.type_signature))
def sequence_map(self, fn, arg):
  """Implements `sequence_map` as defined in `api/intrinsics.py`.

  Both inputs are converted to TFF values on this object's context stack.
  A plain sequence `arg` is mapped directly with a sequence-map building
  block. A federated `arg` is handled by currying a `SEQUENCE_MAP` intrinsic
  over `fn` and handing the result to `self.federated_map`.

  Args:
    fn: The per-element mapping function. Its converted value must have a
      `computation_types.FunctionType` signature (checked via
      `py_typecheck.check_type`).
    arg: A sequence, or a federated sequence.

  Returns:
    A `value_impl.ValueImpl` for the sequence case, or whatever
    `self.federated_map` returns for the federated case.

  Raises:
    TypeError: If `arg` is neither a sequence nor a federated value.
  """
  fn = value_impl.to_value(fn, None, self._context_stack)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  arg = value_impl.to_value(arg, None, self._context_stack)
  arg_type = arg.type_signature

  if arg_type.is_sequence():
    fn_comp = value_impl.ValueImpl.get_comp(fn)
    arg_comp = value_impl.ValueImpl.get_comp(arg)
    mapped = self._bind_comp_as_reference(
        building_block_factory.create_sequence_map(fn_comp, arg_comp))
    return value_impl.ValueImpl(mapped, self._context_stack)

  if not arg_type.is_federated():
    raise TypeError(
        'Cannot apply `tff.sequence_map()` to a value of type {}.'.format(
            arg_type))

  # Intrinsic signature: <(T -> U), T*> -> U*, derived from `fn` itself.
  fn_type = fn.type_signature
  sequence_map_intrinsic = building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_MAP.uri,
      computation_types.FunctionType(
          (fn_type, computation_types.SequenceType(fn_type.parameter)),
          computation_types.SequenceType(fn_type.result)))
  curried = value_utils.get_curried(
      value_impl.ValueImpl(sequence_map_intrinsic, self._context_stack))
  # Binding `fn` leaves a T* -> U* function to run at each placement.
  return self.federated_map(curried(fn), arg)
def test_get_curried(self):
  """`get_curried` turns a binary TF computation into nested unary lambdas."""
  tf_add = computations.tf_computation(
      lambda a, b: tf.add(a, b),  # pylint: disable=unnecessary-lambda
      [tf.int32, tf.int32])
  add_proto = computation_impl.ComputationImpl.get_proto(tf_add)
  add_block = building_blocks.ComputationBuildingBlock.from_proto(add_proto)
  add_numbers = value_impl.Value(add_block)

  curried = value_utils.get_curried(add_numbers)

  self.assertEqual(
      str(curried.type_signature), '(int32 -> (int32 -> int32))')
  # Compiled computation names are uniquified so the expected string is
  # stable across runs.
  uniquified, _ = tree_transformations.uniquify_compiled_computation_names(
      curried.comp)
  self.assertEqual(
      uniquified.compact_representation(),
      '(arg0 -> (arg1 -> comp#1(<arg0,arg1>)))')
def test_get_curried(self):
  """`get_curried` on a compiled `tf.add` yields `int32 -> (int32 -> int32)`."""
  int32_type = computation_types.TensorType(tf.int32)
  add_proto, add_type = (
      tensorflow_computation_factory.create_binary_operator(
          tf.add, int32_type, int32_type))
  compiled_add = building_blocks.CompiledComputation(
      proto=add_proto, name='test', type_signature=add_type)

  curried = value_utils.get_curried(value_impl.Value(compiled_add))

  self.assertEqual(
      curried.type_signature.compact_representation(),
      '(int32 -> (int32 -> int32))')
  self.assertEqual(
      curried.comp.compact_representation(),
      '(arg0 -> (arg1 -> comp#test(<arg0,arg1>)))')