Example #1
0
def sequence_map(fn, arg):
    """Maps a TFF sequence `arg` pointwise using a given function `fn`.

    This function supports two modes of usage:

    * When applied to a non-federated sequence, it maps individual elements of
      the sequence pointwise. If the supplied `fn` is of type `T->U` and
      the sequence `arg` is of type `T*` (a sequence of `T`-typed elements),
      the result is a sequence of type `U*` (a sequence of `U`-typed elements),
      with each element of the input sequence individually mapped by `fn`.
      In this mode of usage, `sequence_map` behaves like a computation with
      type signature `<T->U,T*> -> U*`.

    * When applied to a federated sequence, `sequence_map` behaves as if it
      were individually applied to each member constituent. In this mode of
      usage, one can think of `sequence_map` as a specialized variant of
      `federated_map` that is designed to work with sequences and allows one
      to specify a `fn` that operates at the level of individual elements.
      Indeed, under the hood, when `sequence_map` is invoked on a federated
      type, it injects `federated_map`, thus emitting expressions like
      `federated_map(x -> sequence_map(fn, x), arg)`.

    Args:
      fn: A mapping function to apply pointwise to elements of `arg`. It is
        converted with `value_impl.to_value`, and its type signature must be
        a `computation_types.FunctionType`.
      arg: A value of a TFF type that is either a sequence, or a federated
        sequence.

    Returns:
      A sequence with the result of applying `fn` pointwise to each
      element of `arg`, or if `arg` was federated, a federated sequence
      with the result of invoking `sequence_map` on member sequences locally
      and independently at each location.

    Raises:
      TypeError: If the arguments are not of the appropriate types.
    """
    fn = value_impl.to_value(fn, None)
    py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
    arg = value_impl.to_value(arg, None)

    if arg.type_signature.is_sequence():
        comp = building_block_factory.create_sequence_map(fn.comp, arg.comp)
        comp = _bind_comp_as_reference(comp)
        return value_impl.Value(comp)
    elif arg.type_signature.is_federated():
        # Build a SEQUENCE_MAP intrinsic of type `<(T->U),T*> -> U*`, curry it
        # into `(T->U) -> (T* -> U*)`, and apply it to `fn`. The resulting
        # `T* -> U*` function is then applied to each member via
        # `federated_map`.
        parameter_type = computation_types.SequenceType(
            fn.type_signature.parameter)
        result_type = computation_types.SequenceType(fn.type_signature.result)
        intrinsic_type = computation_types.FunctionType(
            (fn.type_signature, parameter_type), result_type)
        intrinsic = building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_MAP.uri,
                                              intrinsic_type)
        intrinsic_impl = value_impl.Value(intrinsic)
        local_fn = value_utils.get_curried(intrinsic_impl)(fn)
        return federated_map(local_fn, arg)
    else:
        raise TypeError(
            'Cannot apply `tff.sequence_map()` to a value of type {}.'.format(
                arg.type_signature))
Example #2
0
    def sequence_map(self, fn, arg):
        """Implements `sequence_map` as defined in `api/intrinsics.py`.

        Args:
          fn: The mapping function (anything accepted by
            `value_impl.to_value`); its type signature must be a
            `computation_types.FunctionType`.
          arg: A sequence value, or a federated value whose members are
            presumably sequences, to be mapped pointwise by `fn`.

        Returns:
          For a sequence `arg`, a `value_impl.ValueImpl` wrapping a bound
          reference to the sequence-map building block. For a federated
          `arg`, the result of `self.federated_map` applied with a curried
          `SEQUENCE_MAP` intrinsic specialized to `fn`.

        Raises:
          TypeError: If `arg` is neither a sequence nor a federated value.
        """
        fn_value = value_impl.to_value(fn, None, self._context_stack)
        py_typecheck.check_type(fn_value.type_signature,
                                computation_types.FunctionType)
        arg_value = value_impl.to_value(arg, None, self._context_stack)
        arg_type = arg_value.type_signature

        # Plain sequence: build the sequence-map building block directly.
        if arg_type.is_sequence():
            mapped_comp = building_block_factory.create_sequence_map(
                value_impl.ValueImpl.get_comp(fn_value),
                value_impl.ValueImpl.get_comp(arg_value))
            return value_impl.ValueImpl(
                self._bind_comp_as_reference(mapped_comp), self._context_stack)

        if not arg_type.is_federated():
            raise TypeError(
                'Cannot apply `tff.sequence_map()` to a value of type {}.'.
                format(arg_type))

        # Federated input: specialize a curried SEQUENCE_MAP intrinsic to
        # `fn`, yielding a sequence-to-sequence function for federated_map.
        fn_type = fn_value.type_signature
        seq_map_type = computation_types.FunctionType(
            (fn_type, computation_types.SequenceType(fn_type.parameter)),
            computation_types.SequenceType(fn_type.result))
        seq_map_intrinsic = value_impl.ValueImpl(
            building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_MAP.uri,
                                      seq_map_type), self._context_stack)
        per_member_fn = value_utils.get_curried(seq_map_intrinsic)(fn_value)
        return self.federated_map(per_member_fn, arg_value)
Example #3
0
    def test_get_curried(self):
        """Checks that `get_curried` turns a binary add into nested lambdas."""
        # Wrap a two-argument TF addition as a `value_impl.Value`, going
        # through the computation's serialized proto.
        adder_fn = computations.tf_computation(
            lambda a, b: tf.add(a, b),  # pylint: disable=unnecessary-lambda
            [tf.int32, tf.int32])
        adder_proto = computation_impl.ComputationImpl.get_proto(adder_fn)
        add_numbers = value_impl.Value(
            building_blocks.ComputationBuildingBlock.from_proto(adder_proto))

        curried = value_utils.get_curried(add_numbers)
        expected_type = '(int32 -> (int32 -> int32))'
        self.assertEqual(str(curried.type_signature), expected_type)

        # Normalize compiled-computation names so the expected string can
        # refer to the addition as `comp#1`.
        renamed, _ = tree_transformations.uniquify_compiled_computation_names(
            curried.comp)
        self.assertEqual(renamed.compact_representation(),
                         '(arg0 -> (arg1 -> comp#1(<arg0,arg1>)))')
Example #4
0
    def test_get_curried(self):
        """Checks that `get_curried` splits a binary operator's operands."""
        int32_type = computation_types.TensorType(tf.int32)
        add_proto, add_type = (
            tensorflow_computation_factory.create_binary_operator(
                tf.add, int32_type, int32_type))
        # Name the compiled computation explicitly so its compact
        # representation is predictable (`comp#test`).
        add_numbers = value_impl.Value(
            building_blocks.CompiledComputation(
                proto=add_proto, name='test', type_signature=add_type))

        curried = value_utils.get_curried(add_numbers)

        # Each operand should now be consumed by its own nested lambda.
        curried_type_repr = curried.type_signature.compact_representation()
        self.assertEqual(curried_type_repr, '(int32 -> (int32 -> int32))')
        curried_comp_repr = curried.comp.compact_representation()
        self.assertEqual(curried_comp_repr,
                         '(arg0 -> (arg1 -> comp#test(<arg0,arg1>)))')