def sequence_map(self, fn, arg):
  """Implements `sequence_map` as defined in `api/intrinsics.py`.

  A sequence-typed `arg` is mapped directly via
  `building_block_factory.create_sequence_map`. A federated `arg` placed at
  the server or at the clients is handled by currying the `SEQUENCE_MAP`
  intrinsic over `fn` and handing the resulting local function to
  `self.federated_map`.

  Args:
    fn: A value convertible (via `value_impl.to_value`) to a function-typed
      TFF value.
    arg: A value convertible to a sequence-typed or federated TFF value.

  Returns:
    For a sequence `arg`, a `value_impl.ValueImpl` wrapping the sequence-map
    computation; for a federated `arg`, whatever `self.federated_map` returns.

  Raises:
    TypeError: If `arg` is neither a sequence nor a federated value, or if a
      federated `arg` is placed somewhere other than SERVER or CLIENTS.
      (`py_typecheck.check_type` additionally rejects a non-function `fn`.)
  """
  fn = value_impl.to_value(fn, None, self._context_stack)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  arg = value_impl.to_value(arg, None, self._context_stack)
  arg_type = arg.type_signature

  if isinstance(arg_type, computation_types.SequenceType):
    # Plain (non-federated) sequence: build the mapping directly on the
    # underlying building blocks.
    sequence_map_comp = building_block_factory.create_sequence_map(
        value_impl.ValueImpl.get_comp(fn), value_impl.ValueImpl.get_comp(arg))
    return value_impl.ValueImpl(sequence_map_comp, self._context_stack)

  if not isinstance(arg_type, computation_types.FederatedType):
    raise TypeError(
        'Cannot apply `tff.sequence_map()` to a value of type {}.'.format(
            arg_type))

  fn_type = fn.type_signature
  # The intrinsic takes (fn, sequence of fn's parameter) and yields a sequence
  # of fn's result; currying it over `fn` leaves a one-argument local function
  # that can be handed to federated_map.
  intrinsic_type = computation_types.FunctionType(
      (fn_type, computation_types.SequenceType(fn_type.parameter)),
      computation_types.SequenceType(fn_type.result))
  sequence_map_intrinsic = value_impl.ValueImpl(
      building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_MAP.uri,
                                intrinsic_type), self._context_stack)
  local_fn = value_utils.get_curried(sequence_map_intrinsic)(fn)

  placement = arg_type.placement
  if placement not in (placements.SERVER, placements.CLIENTS):
    raise TypeError('Unsupported placement {}.'.format(placement))
  return self.federated_map(local_fn, arg)
Example #2
0
  def test_get_curried(self):
    """Currying a two-argument TF addition yields nested unary functions."""
    add_computation = computations.tf_computation(tf.add,
                                                  [tf.int32, tf.int32])
    add_proto = computation_impl.ComputationImpl.get_proto(add_computation)
    add_block = building_blocks.ComputationBuildingBlock.from_proto(add_proto)
    add_numbers = value_impl.ValueImpl(add_block, _context_stack)

    curried = value_utils.get_curried(add_numbers)
    self.assertEqual(str(curried.type_signature), '(int32 -> (int32 -> int32))')

    # Uniquify compiled-computation names so the compact representation is
    # deterministic (comp#1) regardless of generated names.
    curried_comp = value_impl.ValueImpl.get_comp(curried)
    uniquified, _ = tree_transformations.uniquify_compiled_computation_names(
        curried_comp)
    self.assertEqual(uniquified.compact_representation(),
                     '(arg0 -> (arg1 -> comp#1(<arg0,arg1>)))')
    def sequence_map(self, mapping_fn, value):
        """Implements `sequence_map` as defined in `api/intrinsics.py`.

        A `SEQUENCE_MAP` intrinsic is typed from `mapping_fn`. Sequence
        values are passed to it directly; federated values are handled by
        currying the intrinsic over `mapping_fn` and lifting the resulting
        local function with `federated_apply` (SERVER placement) or
        `federated_map` (CLIENTS placement).

        Args:
          mapping_fn: As in `api/intrinsics.py`.
          value: As in `api/intrinsics.py`.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: As in `api/intrinsics.py`.
        """
        mapping_fn = value_impl.to_value(mapping_fn, None, self._context_stack)
        py_typecheck.check_type(mapping_fn.type_signature,
                                computation_types.FunctionType)
        fn_type = mapping_fn.type_signature
        # Intrinsic signature: (mapping_fn, sequence of its parameter type)
        # -> sequence of its result type.
        intrinsic_type = computation_types.FunctionType(
            [fn_type, computation_types.SequenceType(fn_type.parameter)],
            computation_types.SequenceType(fn_type.result))
        sequence_map_intrinsic = value_impl.ValueImpl(
            computation_building_blocks.Intrinsic(
                intrinsic_defs.SEQUENCE_MAP.uri, intrinsic_type),
            self._context_stack)

        value = value_impl.to_value(value, None, self._context_stack)
        value_type = value.type_signature

        if isinstance(value_type, computation_types.SequenceType):
            return sequence_map_intrinsic(mapping_fn, value)

        if not isinstance(value_type, computation_types.FederatedType):
            raise TypeError(
                'Cannot apply `tff.sequence_map()` to a value of type {}.'.
                format(str(value_type)))

        local_func = value_utils.get_curried(sequence_map_intrinsic)(
            mapping_fn)
        placement = value_type.placement
        if placement is placements.SERVER:
            return self.federated_apply(local_func, value)
        if placement is placements.CLIENTS:
            return self.federated_map(local_func, value)
        raise TypeError('Unsupported placement {}.'.format(str(placement)))