Example #1
def construct_map_or_apply(fn, arg):
  """Injects intrinsic to allow application of `fn` to federated `arg`.

  Args:
    fn: An instance of `computation_building_blocks.ComputationBuildingBlock` of
      functional type to be wrapped with intrinsic in order to call on `arg`.
    arg: `computation_building_blocks.ComputationBuildingBlock` instance of
      federated type for which to construct intrinsic in order to call `fn` on
      `arg`. `member` of `type_signature` of `arg` must be assignable to
      `parameter` of `type_signature` of `fn`.

  Returns:
    A `computation_building_blocks.Call` that applies `fn` to `arg` via the
    appropriate federated intrinsic.
  """
  py_typecheck.check_type(fn,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(arg,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
  type_utils.check_assignable_from(fn.type_signature.parameter,
                                   arg.type_signature.member)
  if arg.type_signature.placement == placement_literals.SERVER:
    result_type = computation_types.FederatedType(fn.type_signature.result,
                                                  arg.type_signature.placement,
                                                  arg.type_signature.all_equal)
    intrinsic_type = computation_types.FunctionType(
        [fn.type_signature, arg.type_signature], result_type)
    intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_APPLY.uri, intrinsic_type)
    tup = computation_building_blocks.Tuple((fn, arg))
    return computation_building_blocks.Call(intrinsic, tup)
  elif arg.type_signature.placement == placement_literals.CLIENTS:
    return create_federated_map(fn, arg)
  else:
    raise TypeError(
        'Unsupported placement {}.'.format(arg.type_signature.placement))
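
The dispatch above routes server-placed arguments through the federated_apply intrinsic and client-placed ones through federated_map. Below is a minimal stand-alone sketch of that dispatch using a hypothetical FederatedValue stand-in rather than the real TFF building blocks:

# Minimal sketch of the placement dispatch above; FederatedValue is a
# hypothetical stand-in, not a TFF class.
from dataclasses import dataclass

@dataclass
class FederatedValue:
  member: object   # a single server value, or a list of per-client values
  placement: str   # 'SERVER' or 'CLIENTS'

def map_or_apply(fn, arg):
  if arg.placement == 'SERVER':
    # One value on the server: apply `fn` once (federated_apply).
    return FederatedValue(fn(arg.member), 'SERVER')
  elif arg.placement == 'CLIENTS':
    # One value per client: map `fn` over all of them (federated_map).
    return FederatedValue([fn(v) for v in arg.member], 'CLIENTS')
  else:
    raise TypeError('Unsupported placement {}.'.format(arg.placement))

print(map_or_apply(lambda x: 2 * x, FederatedValue(3, 'SERVER')))
print(map_or_apply(lambda x: 2 * x, FederatedValue([1, 2, 3], 'CLIENTS')))
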
Example #2
 async def create_selection(self, source, index=None, name=None):
     py_typecheck.check_type(source, CachedValue)
     py_typecheck.check_type(source.type_signature,
                             computation_types.NamedTupleType)
     source_val = await source.target_future
     if index is not None:
         py_typecheck.check_none(name)
         identifier_str = '{}[{}]'.format(source.identifier, index)
         type_spec = source.type_signature[index]
     else:
         py_typecheck.check_not_none(name)
         identifier_str = '{}.{}'.format(source.identifier, name)
         type_spec = getattr(source.type_signature, name)
     identifier = CachedValueIdentifier(identifier_str)
     try:
         cached_value = self._cache[identifier]
     except KeyError:
         target_future = asyncio.ensure_future(
             self._target_executor.create_selection(source_val,
                                                    index=index,
                                                    name=name))
         cached_value = CachedValue(identifier, None, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     target_value = await cached_value.target_future
     type_utils.check_assignable_from(type_spec,
                                      target_value.type_signature)
     return cached_value
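
This example (and several of the later caching-executor ones) follows the same cache-or-create shape: derive a string identifier for the requested operation, look it up, schedule work on the target executor only on a miss, and await the memoized future. A simplified, self-contained sketch of just that shape, omitting the final type check; the classes below are illustrative stand-ins, not the real TFF CachedValue machinery:

# Simplified cache-or-create sketch with stand-in classes (not the TFF ones).
import asyncio

class SimpleCachingExecutor:

  def __init__(self, target):
    self._target = target   # object exposing the async method we delegate to
    self._cache = {}        # identifier string -> asyncio future

  async def create_selection(self, source_id, source_val, index):
    identifier = '{}[{}]'.format(source_id, index)
    try:
      future = self._cache[identifier]
    except KeyError:
      # Cache miss: schedule the delegated work once and memoize the future,
      # so later callers await the same underlying computation.
      future = asyncio.ensure_future(
          self._target.create_selection(source_val, index))
      self._cache[identifier] = future
    return await future

class ListTarget:

  async def create_selection(self, source_val, index):
    return source_val[index]

async def _demo():
  executor = SimpleCachingExecutor(ListTarget())
  print(await executor.create_selection('t1', [10, 20, 30], 1))  # computed
  print(await executor.create_selection('t1', [10, 20, 30], 1))  # from cache

asyncio.run(_demo())
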
Example #3
 async def create_call(self, comp, arg=None):
     py_typecheck.check_type(comp, CachedValue)
     py_typecheck.check_type(comp.type_signature,
                             computation_types.FunctionType)
     to_gather = [comp.target_future]
     if arg is not None:
         py_typecheck.check_type(arg, CachedValue)
         type_utils.check_assignable_from(comp.type_signature.parameter,
                                          arg.type_signature)
         to_gather.append(arg.target_future)
         identifier_str = '{}({})'.format(comp.identifier, arg.identifier)
     else:
         identifier_str = '{}()'.format(comp.identifier)
     gathered = await asyncio.gather(*to_gather)
     type_spec = comp.type_signature.result
     identifier = CachedValueIdentifier(identifier_str)
     try:
         cached_value = self._cache[identifier]
     except KeyError:
         target_future = asyncio.ensure_future(
             self._target_executor.create_call(*gathered))
         cached_value = CachedValue(identifier, None, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     target_value = await cached_value.target_future
     type_utils.check_assignable_from(type_spec,
                                      target_value.type_signature)
     return cached_value
Example #4
 async def create_call(self, comp, arg=None):
   py_typecheck.check_type(comp, LambdaExecutorValue)
   py_typecheck.check_type(comp.type_signature, computation_types.FunctionType)
   param_type = comp.type_signature.parameter
   if param_type is not None:
     py_typecheck.check_type(arg, LambdaExecutorValue)
     if not type_utils.is_assignable_from(param_type, arg.type_signature):
       arg_type = type_utils.get_argument_type(arg.type_signature)
       type_utils.check_assignable_from(param_type, arg_type)
       return await self.create_call(comp, await self.create_call(arg))
   else:
     py_typecheck.check_none(arg)
   comp_repr = comp.internal_representation
   if isinstance(comp_repr, executor_value_base.ExecutorValue):
     return LambdaExecutorValue(await self._target_executor.create_call(
         comp_repr, await self._delegate(arg) if arg is not None else None))
   elif callable(comp_repr):
     return await comp_repr(arg)
   else:
     # An anonymous tuple could not possibly have a functional type signature,
     # so this is the only case left to handle.
     py_typecheck.check_type(comp_repr, pb.Computation)
     eval_result = await self._evaluate(comp_repr, comp.scope)
     py_typecheck.check_type(eval_result, LambdaExecutorValue)
     if arg is not None:
       py_typecheck.check_type(eval_result.type_signature,
                               computation_types.FunctionType)
       type_utils.check_assignable_from(eval_result.type_signature.parameter,
                                        arg.type_signature)
       return await self.create_call(eval_result, arg)
     elif isinstance(eval_result.type_signature,
                     computation_types.FunctionType):
       return await self.create_call(eval_result, arg=None)
     else:
       return eval_result
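
The recursion near the top of this method (call the argument first, then retry the outer call) coerces values that arrive embedded as no-arg functions. A toy model of that coercion with plain Python callables and string 'types'; everything here is illustrative, not the TFF API:

# Toy model of the argument-coercion step above; types are plain strings.
def create_call(param_type, fn, arg_type, arg):
  if arg_type == param_type:
    return fn(arg)
  if arg_type == ('->', param_type):
    # The argument arrived embedded as a no-arg function ( -> T): unwrap it
    # by calling it, then retry the outer call with the unwrapped value.
    return create_call(param_type, fn, param_type, arg())
  raise TypeError('Function takes an argument of type {}, but was supplied '
                  'an argument of type {}.'.format(param_type, arg_type))

# A value of type ( -> int) passed where an int is expected:
print(create_call('int', lambda x: x + 1, ('->', 'int'), lambda: 41))  # 42
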
Example #5
  async def create_call(self, comp, arg=None):
    py_typecheck.check_type(comp, LambdaExecutorValue)
    py_typecheck.check_type(comp.type_signature, computation_types.FunctionType)
    param_type = comp.type_signature.parameter
    if param_type is not None:
      py_typecheck.check_type(arg, LambdaExecutorValue)
      arg_type = arg.type_signature  # pytype: disable=attribute-error
      if not type_utils.is_assignable_from(param_type, arg_type):
        adjusted_arg_type = arg_type
        if isinstance(arg_type, computation_types.FunctionType):
          # We may have a non-functional type embedded as a no-arg function.
          adjusted_arg_type = type_utils.get_argument_type(arg_type)

        # HACK: The second (`or`) check covers the case where a no-arg lambda
        # was passed to the lambda executor as a value.
        # `get_function_type` is used above to transform non-function
        # types into no-arg lambda types, which are unwrapped by
        # `get_argument_type` above. Since our value of type `( -> T)`
        # then has an `adjusted_arg_type` of type `T`, the above
        # check will fail, so we fall back to checking the unadjusted
        # type. Note: this will permit some values passed in as type `T`
        # to be used as values of type `( -> T)`.
        if not (type_utils.is_assignable_from(param_type, adjusted_arg_type) or
                type_utils.is_assignable_from(param_type, arg_type)):
          raise TypeError('LambdaExecutor asked to create call with '
                          'incompatible type specifications. Function '
                          'takes an argument of type {}, but was supplied '
                          'an argument of type {} (or possibly {})'.format(
                              param_type, adjusted_arg_type, arg_type))
        arg = await self.create_call(arg)
        return await self.create_call(comp, arg)
    else:
      py_typecheck.check_none(arg)
    comp_repr = comp.internal_representation
    if isinstance(comp_repr, executor_value_base.ExecutorValue):
      delegated_arg = await self._delegate(arg) if arg is not None else None
      return LambdaExecutorValue(await self._target_executor.create_call(
          comp_repr, delegated_arg))
    elif callable(comp_repr):
      return await comp_repr(arg)
    else:
      # An anonymous tuple could not possibly have a functional type signature,
      # so this is the only case left to handle.
      py_typecheck.check_type(comp_repr, pb.Computation)
      eval_result = await self._evaluate(comp_repr, comp.scope)
      py_typecheck.check_type(eval_result, LambdaExecutorValue)
      if arg is not None:
        py_typecheck.check_type(eval_result.type_signature,
                                computation_types.FunctionType)
        type_utils.check_assignable_from(eval_result.type_signature.parameter,
                                         arg.type_signature)
        return await self.create_call(eval_result, arg)
      elif isinstance(eval_result.type_signature,
                      computation_types.FunctionType):
        return await self.create_call(eval_result, arg=None)
      else:
        return eval_result
Example #6
  def invoke(self, fn, arg):
    comp = self._compile(fn)
    cardinalities = {}
    root_context = ComputationContext(cardinalities=cardinalities)
    computed_comp = self._compute(comp, root_context)
    type_utils.check_assignable_from(comp.type_signature,
                                     computed_comp.type_signature)

    def _convert_to_py_container(value, type_spec):
      """Converts value to a Python container if type_spec has an annotation."""
      if type_utils.is_anon_tuple_with_py_container(value, type_spec):
        return type_utils.convert_to_py_container(value, type_spec)
      elif isinstance(type_spec, computation_types.SequenceType):
        if all(
            type_utils.is_anon_tuple_with_py_container(
                element, type_spec.element) for element in value):
          return [
              type_utils.convert_to_py_container(element, type_spec.element)
              for element in value
          ]
      return value

    if not isinstance(computed_comp.type_signature,
                      computation_types.FunctionType):
      if arg is not None:
        raise TypeError('Unexpected argument {}.'.format(arg))
      else:
        value = computed_comp.value
        result_type = fn.type_signature.result
        return _convert_to_py_container(value, result_type)
    else:
      if arg is not None:

        def _handle_callable(fn, fn_type):
          py_typecheck.check_type(fn, computation_base.Computation)
          type_utils.check_assignable_from(fn.type_signature, fn_type)
          computed_fn = self._compute(self._compile(fn), root_context)
          return computed_fn.value

        computed_arg = ComputedValue(
            to_representation_for_type(arg,
                                       computed_comp.type_signature.parameter,
                                       _handle_callable),
            computed_comp.type_signature.parameter)
        cardinalities.update(
            runtime_utils.infer_cardinalities(computed_arg.value,
                                              computed_arg.type_signature))
      else:
        computed_arg = None
      result = computed_comp.value(computed_arg)
      py_typecheck.check_type(result, ComputedValue)
      type_utils.check_assignable_from(comp.type_signature.result,
                                       result.type_signature)
      value = result.value
      fn_result_type = fn.type_signature.result
      return _convert_to_py_container(value, fn_result_type)
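
_convert_to_py_container above only rewraps the result when the TFF type carries a Python container annotation, and handles sequences element by element. A toy stand-in for that behavior with the container annotation passed explicitly (the real annotation lives on the TFF type, so this argument is hypothetical):

# Toy stand-in: convert a list of (name, value) pairs into an annotated
# Python container, or leave the value untouched when there is no annotation.
import collections

def convert_to_container(elements, container_type=None):
  if container_type is None:
    return elements                     # no annotation: leave untouched
  if issubclass(container_type, dict):
    return container_type(elements)     # dict-like: keep the (name, value) pairs
  return container_type(*[v for _, v in elements])

Point = collections.namedtuple('Point', ['x', 'y'])
print(convert_to_container([('x', 1), ('y', 2)], Point))
print(convert_to_container([('x', 1), ('y', 2)], collections.OrderedDict))
print(convert_to_container([('x', 1), ('y', 2)]))
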
Example #7
 async def create_call(self, comp, arg=None):
     py_typecheck.check_type(comp, FederatingExecutorValue)
     if arg is not None:
         py_typecheck.check_type(arg, FederatingExecutorValue)
         py_typecheck.check_type(comp.type_signature,
                                 computation_types.FunctionType)
         param_type = comp.type_signature.parameter
         type_utils.check_assignable_from(param_type, arg.type_signature)
         arg = FederatingExecutorValue(arg.internal_representation,
                                       param_type)
     if isinstance(comp.internal_representation, pb.Computation):
         which_computation = comp.internal_representation.WhichOneof(
             'computation')
         comp_type_signature = comp.type_signature
         if which_computation == 'lambda':
             # Pull the inner computation out of called no-arg lambdas.
             if comp.type_signature.parameter is not None:
                 raise ValueError(
                     'Directly calling lambdas with arguments is unsupported. '
                     'Found call to lambda with type {}.'.format(
                         comp.type_signature))
             return await self.create_value(
                 getattr(comp.internal_representation, 'lambda').result,
                 comp.type_signature.result)
         elif which_computation == 'tensorflow':
             # Run tensorflow computations.
             child = self._target_executors[None][0]
             embedded_comp = await child.create_value(
                 comp.internal_representation, comp_type_signature)
             if arg is not None:
                 embedded_arg = await executor_utils.delegate_entirely_to_executor(
                     arg.internal_representation, arg.type_signature, child)
             else:
                 embedded_arg = None
             result = await child.create_call(embedded_comp, embedded_arg)
             return FederatingExecutorValue(result, result.type_signature)
         else:
             raise ValueError(
                 'Directly calling computations of type {} is unsupported.'.
                 format(which_computation))
     elif isinstance(comp.internal_representation,
                     intrinsic_defs.IntrinsicDef):
         coro = getattr(
             self, '_compute_intrinsic_{}'.format(
                 comp.internal_representation.uri), None)
         if coro is not None:
             return await coro(arg)  # pylint: disable=not-callable
         else:
             raise NotImplementedError(
                 'Support for intrinsic "{}" has not been implemented yet.'.
                 format(comp.internal_representation.uri))
     else:
         raise ValueError(
             'Calling objects of type {} is unsupported.'.format(
                 py_typecheck.type_string(type(
                     comp.internal_representation))))
Example #8
def deserialize_value(value_proto):
    """Deserializes a value (of any type) from `executor_pb2.Value`.

  Args:
    value_proto: An instance of `executor_pb2.Value`.

  Returns:
    A tuple `(value, type_spec)`, where `value` is a deserialized
    representation of the transmitted value (e.g., Numpy array, or a
    `pb.Computation` instance), and `type_spec` is an instance of
    `tff.Type` that represents its type.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the value is malformed.
  """
    py_typecheck.check_type(value_proto, executor_pb2.Value)
    which_value = value_proto.WhichOneof('value')
    if which_value == 'tensor':
        return deserialize_tensor_value(value_proto)
    elif which_value == 'computation':
        return (value_proto.computation,
                type_serialization.deserialize_type(
                    value_proto.computation.type))
    elif which_value == 'tuple':
        val_elems = []
        type_elems = []
        for e in value_proto.tuple.element:
            name = e.name if e.name else None
            e_val, e_type = deserialize_value(e.value)
            val_elems.append((name, e_val))
            type_elems.append((name, e_type) if name else e_type)
        return (anonymous_tuple.AnonymousTuple(val_elems),
                computation_types.NamedTupleType(type_elems))
    elif which_value == 'sequence':
        return deserialize_sequence_value(value_proto.sequence)
    elif which_value == 'federated':
        type_spec = type_serialization.deserialize_type(
            computation_pb2.Type(federated=value_proto.federated.type))
        value = []
        for item in value_proto.federated.value:
            item_value, item_type = deserialize_value(item)
            type_utils.check_assignable_from(type_spec.member, item_type)
            value.append(item_value)
        if type_spec.all_equal:
            if len(value) == 1:
                value = value[0]
            else:
                raise ValueError(
                    'Expected an all_equal value with a single member '
                    'constituent, found {}.'.format(len(value)))
        return value, type_spec
    else:
        raise ValueError(
            'Unable to deserialize a value of type {}.'.format(which_value))
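
The tuple branch above recurses element by element and rebuilds the type alongside the value. A stripped-down sketch of that shape, with plain Python dicts standing in for the protobuf messages (the dict layout is hypothetical):

# Stripped-down sketch of recursive value deserialization; the dict layout is
# a hypothetical stand-in for the executor_pb2.Value messages.
def deserialize(value):
  kind = value['kind']
  if kind == 'tensor':
    return value['data'], value['dtype']
  elif kind == 'tuple':
    val_elems, type_elems = [], []
    for name, element in value['elements']:
      e_val, e_type = deserialize(element)
      val_elems.append((name, e_val))
      type_elems.append((name, e_type) if name else e_type)
    return val_elems, type_elems
  else:
    raise ValueError('Unable to deserialize a value of kind {}.'.format(kind))

nested = {'kind': 'tuple', 'elements': [
    ('a', {'kind': 'tensor', 'data': 1, 'dtype': 'int32'}),
    (None, {'kind': 'tensor', 'data': 2.0, 'dtype': 'float32'}),
]}
print(deserialize(nested))
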
Example #9
 def _sequence_map(self, arg):
   mapping_type = arg.type_signature[0]
   py_typecheck.check_type(mapping_type, computation_types.FunctionType)
   sequence_type = arg.type_signature[1]
   py_typecheck.check_type(sequence_type, computation_types.SequenceType)
   type_utils.check_assignable_from(mapping_type.parameter,
                                    sequence_type.element)
   fn = arg.value[0]
   result_val = [
       fn(ComputedValue(x, mapping_type.parameter)).value for x in arg.value[1]
   ]
   result_type = computation_types.SequenceType(mapping_type.result)
   return ComputedValue(result_val, result_type)
Example #10
def stamp_computed_value_into_graph(value, graph):
  """Stamps `value` in `graph`.

  Args:
    value: An instance of `ComputedValue`.
    graph: The graph to stamp in.

  Returns:
    A Python object made of tensors stamped into `graph`, `tf.data.Dataset`s,
    and `anonymous_tuple.AnonymousTuple`s that structurally corresponds to the
    value passed at input.
  """
  if value is None:
    return None
  else:
    py_typecheck.check_type(value, ComputedValue)
    value = ComputedValue(
        to_representation_for_type(value.value, value.type_signature),
        value.type_signature)
    py_typecheck.check_type(graph, tf.Graph)
    if isinstance(value.type_signature, computation_types.TensorType):
      if isinstance(value.value, np.ndarray):
        value_type = computation_types.TensorType(
            tf.dtypes.as_dtype(value.value.dtype),
            tf.TensorShape(value.value.shape))
        type_utils.check_assignable_from(value.type_signature, value_type)
        with graph.as_default():
          return tf.constant(value.value)
      else:
        with graph.as_default():
          return tf.constant(
              value.value,
              dtype=value.type_signature.dtype,
              shape=value.type_signature.shape)
    elif isinstance(value.type_signature, computation_types.NamedTupleType):
      elements = anonymous_tuple.to_elements(value.value)
      type_elements = anonymous_tuple.to_elements(value.type_signature)
      stamped_elements = []
      for idx, (k, v) in enumerate(elements):
        computed_v = ComputedValue(v, type_elements[idx][1])
        stamped_v = stamp_computed_value_into_graph(computed_v, graph)
        stamped_elements.append((k, stamped_v))
      return anonymous_tuple.AnonymousTuple(stamped_elements)
    elif isinstance(value.type_signature, computation_types.SequenceType):
      return tensorflow_utils.make_data_set_from_elements(
          graph, value.value, value.type_signature.element)
    else:
      raise NotImplementedError(
          'Unable to embed a computed value of type {} in graph.'.format(
              value.type_signature))
Example #11
def serialize_tensor_value(value, type_spec=None):
    """Serializes a tensor value into `executor_pb2.Value`.

  Args:
    value: A Numpy array or other object understood by `tf.make_tensor_proto`.
    type_spec: An optional type spec, a `tff.TensorType` or something
      convertible to it.

  Returns:
    A tuple `(value_proto, ret_type_spec)` in which `value_proto` is an instance
    of `executor_pb2.Value` with the serialized content of `value`, and
    `ret_type_spec` is the type of the serialized value. The `ret_type_spec` is
    the same as the argument `type_spec` if that argument was not `None`. If
    the argument was `None`, `ret_type_spec` is a type determined from `value`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the value is malformed.
  """
    if isinstance(value, tf.Tensor):
        if type_spec is None:
            type_spec = computation_types.TensorType(
                dtype=tf.dtypes.as_dtype(value.dtype),
                shape=tf.TensorShape(value.shape))
        value = value.numpy()
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
        py_typecheck.check_type(type_spec, computation_types.TensorType)
        if isinstance(value, np.ndarray):
            tensor_proto = tf.make_tensor_proto(value,
                                                dtype=type_spec.dtype,
                                                verify_shape=False)
            type_utils.check_assignable_from(
                type_spec,
                computation_types.TensorType(
                    dtype=tf.dtypes.as_dtype(tensor_proto.dtype),
                    shape=tf.TensorShape(tensor_proto.tensor_shape)))
        else:
            tensor_proto = tf.make_tensor_proto(value,
                                                dtype=type_spec.dtype,
                                                shape=type_spec.shape,
                                                verify_shape=True)
    else:
        tensor_proto = tf.make_tensor_proto(value)
        type_spec = computation_types.TensorType(
            dtype=tf.dtypes.as_dtype(tensor_proto.dtype),
            shape=tf.TensorShape(tensor_proto.tensor_shape))
    any_pb = any_pb2.Any()
    any_pb.Pack(tensor_proto)
    return executor_pb2.Value(tensor=any_pb), type_spec
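
A small round trip of the tensor branch above: serialize a NumPy array to a TensorProto, pack it into an Any, then unpack and decode it again. The sketch below exercises only TensorFlow and protobuf APIs, not the TFF wrappers:

# Round-trip sketch: NumPy array -> TensorProto -> Any -> TensorProto -> array.
import numpy as np
import tensorflow as tf
from google.protobuf import any_pb2
from tensorflow.core.framework import tensor_pb2

value = np.array([[1, 2], [3, 4]], dtype=np.int32)

tensor_proto = tf.make_tensor_proto(value)   # serialize, as in the code above
any_pb = any_pb2.Any()
any_pb.Pack(tensor_proto)

restored_proto = tensor_pb2.TensorProto()    # unpack and decode again
any_pb.Unpack(restored_proto)
restored = tf.make_ndarray(restored_proto)
assert (restored == value).all()
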
Example #12
 def _compute_selection(self, comp, context):
   py_typecheck.check_type(comp, building_blocks.Selection)
   source = self._compute(comp.source, context)
   py_typecheck.check_type(source.type_signature,
                           computation_types.NamedTupleType)
   py_typecheck.check_type(source.value, anonymous_tuple.AnonymousTuple)
   if comp.name is not None:
     result_value = getattr(source.value, comp.name)
     result_type = getattr(source.type_signature, comp.name)
   else:
     assert comp.index is not None
     result_value = source.value[comp.index]
     result_type = source.type_signature[comp.index]
   type_utils.check_assignable_from(comp.type_signature, result_type)
   return ComputedValue(result_value, result_type)
Example #13
 def _compute_tuple(self, comp, context):
   py_typecheck.check_type(comp, building_blocks.Tuple)
   result_elements = []
   result_type_elements = []
   for k, v in anonymous_tuple.iter_elements(comp):
     computed_v = self._compute(v, context)
     type_utils.check_assignable_from(v.type_signature,
                                      computed_v.type_signature)
     result_elements.append((k, computed_v.value))
     result_type_elements.append((k, computed_v.type_signature))
   return ComputedValue(
       anonymous_tuple.AnonymousTuple(result_elements),
       computation_types.NamedTupleType([
           (k, v) if k else v for k, v in result_type_elements
       ]))
Example #14
 async def create_call(self, comp, arg=None):
     py_typecheck.check_type(comp, LambdaExecutorValue)
     py_typecheck.check_type(comp.type_signature,
                             computation_types.FunctionType)
     param_type = comp.type_signature.parameter
     if param_type is not None:
         py_typecheck.check_type(arg, LambdaExecutorValue)
         arg_type = arg.type_signature  # pytype: disable=attribute-error
         if not type_utils.is_assignable_from(param_type, arg_type):
             if isinstance(arg_type, computation_types.FunctionType):
                 # We may have a non-functional type embedded as a no-arg function.
                 arg_type = type_utils.get_argument_type(arg_type)
             if not type_utils.is_assignable_from(param_type, arg_type):
                 raise TypeError(
                     'LambdaExecutor asked to create call with '
                     'incompatible type specifications. Function '
                     'takes an argument of type {}, but was supplied '
                     'an argument of type {}'.format(param_type, arg_type))
             arg = await self.create_call(arg)
             return await self.create_call(comp, arg)
     else:
         py_typecheck.check_none(arg)
     comp_repr = comp.internal_representation
     if isinstance(comp_repr, executor_value_base.ExecutorValue):
          delegated_arg = (await self._delegate(arg)
                           if arg is not None else None)
         return LambdaExecutorValue(await self._target_executor.create_call(
             comp_repr, delegated_arg))
     elif callable(comp_repr):
         return await comp_repr(arg)
     else:
         # An anonymous tuple could not possibly have a functional type signature,
         # so this is the only case left to handle.
         py_typecheck.check_type(comp_repr, pb.Computation)
         eval_result = await self._evaluate(comp_repr, comp.scope)
         py_typecheck.check_type(eval_result, LambdaExecutorValue)
         if arg is not None:
             py_typecheck.check_type(eval_result.type_signature,
                                     computation_types.FunctionType)
             type_utils.check_assignable_from(
                 eval_result.type_signature.parameter, arg.type_signature)
             return await self.create_call(eval_result, arg)
         elif isinstance(eval_result.type_signature,
                         computation_types.FunctionType):
             return await self.create_call(eval_result, arg=None)
         else:
             return eval_result
Example #15
def _stamp_value_into_graph(value, type_signature, graph):
    """Stamps `value` in `graph` as an object of type `type_signature`.

  Args:
    value: A value to stamp.
    type_signature: An instance of `computation_types.Type`.
    graph: The graph to stamp in.

  Returns:
    A Python object made of tensors stamped into `graph`, `tf.data.Dataset`s,
    and `anonymous_tuple.AnonymousTuple`s that structurally corresponds to the
    value passed at input.
  """
    py_typecheck.check_type(type_signature, computation_types.Type)
    py_typecheck.check_type(graph, tf.Graph)
    if value is None:
        return None
    if isinstance(type_signature, computation_types.TensorType):
        if isinstance(value, np.ndarray):
            value_type = computation_types.TensorType(
                tf.dtypes.as_dtype(value.dtype), tf.TensorShape(value.shape))
            type_utils.check_assignable_from(type_signature, value_type)
            with graph.as_default():
                return tf.constant(value)
        else:
            with graph.as_default():
                return tf.constant(value,
                                   dtype=type_signature.dtype,
                                   shape=type_signature.shape)
    elif isinstance(type_signature, computation_types.NamedTupleType):
        if isinstance(value, (list, dict)):
            value = anonymous_tuple.from_container(value)
        stamped_elements = []
        named_type_signatures = anonymous_tuple.to_elements(type_signature)
        for (name, type_signature), element in zip(named_type_signatures,
                                                   value):
            stamped_element = _stamp_value_into_graph(element, type_signature,
                                                      graph)
            stamped_elements.append((name, stamped_element))
        return anonymous_tuple.AnonymousTuple(stamped_elements)
    elif isinstance(type_signature, computation_types.SequenceType):
        return tensorflow_utils.make_data_set_from_elements(
            graph, value, type_signature.element)
    else:
        raise NotImplementedError(
            'Unable to stamp a value of type {} in graph.'.format(
                type_signature))
Example #16
 async def create_value(self, value, type_spec=None):
     type_spec = computation_types.to_type(type_spec)
     if isinstance(value, computation_impl.ComputationImpl):
         return await self.create_value(
             computation_impl.ComputationImpl.get_proto(value),
             type_utils.reconcile_value_with_type_spec(value, type_spec))
     py_typecheck.check_type(type_spec, computation_types.Type)
     hashable_key = _get_hashable_key(value, type_spec)
     try:
         identifier = self._cache[hashable_key]
     except KeyError:
         identifier = None
     except TypeError as err:
         raise RuntimeError(
              'Failed to perform a hash table lookup with a value of Python '
             'type {} and TFF type {}, and payload {}: {}'.format(
                 py_typecheck.type_string(type(value)), type_spec, value,
                 err))
     if isinstance(identifier, CachedValueIdentifier):
         try:
             cached_value = self._cache[identifier]
         except KeyError:
             cached_value = None
          # It may be that the same payload appeared with a mismatching type
          # spec, which may be a legitimate use case if (as it happens) the
          # payload alone does not uniquely determine the type, so we simply
          # opt not to reuse the cached value and fall back on the regular
          # behavior.
         if type_spec is not None and not type_utils.are_equivalent_types(
                 cached_value.type_signature, type_spec):
             identifier = None
     else:
         identifier = None
     if identifier is None:
         self._num_values_created = self._num_values_created + 1
         identifier = CachedValueIdentifier(str(self._num_values_created))
         self._cache[hashable_key] = identifier
         target_future = asyncio.ensure_future(
             self._target_executor.create_value(value, type_spec))
         cached_value = None
     if cached_value is None:
         cached_value = CachedValue(identifier, hashable_key, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     target_value = await cached_value.target_future
     type_utils.check_assignable_from(type_spec,
                                      target_value.type_signature)
     return cached_value
Example #17
 async def create_call(self, comp, arg=None):
     py_typecheck.check_type(comp, CompositeValue)
     if arg is not None:
         py_typecheck.check_type(arg, CompositeValue)
         py_typecheck.check_type(comp.type_signature,
                                 computation_types.FunctionType)
         param_type = comp.type_signature.parameter
         type_utils.check_assignable_from(param_type, arg.type_signature)
         arg = CompositeValue(arg.internal_representation, param_type)
     if isinstance(comp.internal_representation, pb.Computation):
         which_computation = comp.internal_representation.WhichOneof(
             'computation')
         if which_computation == 'tensorflow':
             call_args = [
                 self._parent_executor.create_value(
                     comp.internal_representation, comp.type_signature)
             ]
             if arg is not None:
                 call_args.append(
                     executor_utils.delegate_entirely_to_executor(
                         arg.internal_representation, arg.type_signature,
                         self._parent_executor))
             result = await self._parent_executor.create_call(
                 *(await asyncio.gather(*call_args)))
             return CompositeValue(result, result.type_signature)
         else:
             raise ValueError(
                 'Directly calling computations of type {} is unsupported.'.
                 format(which_computation))
     elif isinstance(comp.internal_representation,
                     intrinsic_defs.IntrinsicDef):
         coro = getattr(
             self, '_compute_intrinsic_{}'.format(
                 comp.internal_representation.uri), None)
         if coro is not None:
             return await coro(arg)  # pylint: disable=not-callable
         else:
             raise NotImplementedError(
                 'Support for intrinsic "{}" has not been implemented yet.'.
                 format(comp.internal_representation.uri))
     else:
         raise ValueError(
             'Calling objects of type {} is unsupported.'.format(
                 py_typecheck.type_string(type(
                     comp.internal_representation))))
Example #18
 def _sequence_reduce(self, arg):
   py_typecheck.check_type(arg.type_signature,
                           computation_types.NamedTupleType)
   sequence_type = arg.type_signature[0]
   py_typecheck.check_type(sequence_type, computation_types.SequenceType)
   zero_type = arg.type_signature[1]
   op_type = arg.type_signature[2]
   py_typecheck.check_type(op_type, computation_types.FunctionType)
   type_utils.check_assignable_from(op_type.parameter,
                                    [zero_type, sequence_type.element])
   total = ComputedValue(arg.value[1], zero_type)
   reduce_fn = arg.value[2]
   for v in arg.value[0]:
     total = reduce_fn(
         ComputedValue(
             anonymous_tuple.AnonymousTuple([(None, total.value), (None, v)]),
             op_type.parameter))
   return total
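
The loop in _sequence_reduce is an ordinary left fold: zero is the initial accumulator and op combines the running total with each element. The same computation over plain Python values (a sketch, not the TFF API):

# sequence_reduce modelled as a plain left fold over Python values.
import functools

def sequence_reduce(elements, zero, op):
  # `op` takes (accumulator, element) and returns the new accumulator,
  # mirroring the <zero_type, element_type> parameter tuple checked above.
  return functools.reduce(op, elements, zero)

print(sequence_reduce([1, 2, 3, 4], 0, lambda total, v: total + v))  # 10
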
Example #19
 def _compute_call(self, comp, context):
   py_typecheck.check_type(comp, building_blocks.Call)
   computed_fn = self._compute(comp.function, context)
   py_typecheck.check_type(computed_fn.type_signature,
                           computation_types.FunctionType)
   if comp.argument is not None:
     computed_arg = self._compute(comp.argument, context)
     type_utils.check_assignable_from(computed_fn.type_signature.parameter,
                                      computed_arg.type_signature)
     computed_arg = fit_argument(computed_arg,
                                 computed_fn.type_signature.parameter, context)
   else:
     computed_arg = None
   result = computed_fn.value(computed_arg)
   py_typecheck.check_type(result, ComputedValue)
   type_utils.check_assignable_from(computed_fn.type_signature.result,
                                    result.type_signature)
   return result
Example #20
 async def create_tuple(self, elements):
     if not isinstance(elements, anonymous_tuple.AnonymousTuple):
         elements = anonymous_tuple.from_container(elements)
     element_strings = []
     element_kv_pairs = anonymous_tuple.to_elements(elements)
     to_gather = []
     type_elements = []
     for k, v in element_kv_pairs:
         py_typecheck.check_type(v, CachedValue)
         to_gather.append(v.target_future)
         if k is not None:
             py_typecheck.check_type(k, str)
             element_strings.append('{}={}'.format(k, v.identifier))
             type_elements.append((k, v.type_signature))
         else:
             element_strings.append(str(v.identifier))
             type_elements.append(v.type_signature)
     type_spec = computation_types.NamedTupleType(type_elements)
     gathered = await asyncio.gather(*to_gather)
     identifier = CachedValueIdentifier('<{}>'.format(
         ','.join(element_strings)))
     try:
         cached_value = self._cache[identifier]
     except KeyError:
         target_future = asyncio.ensure_future(
             self._target_executor.create_tuple(
                 anonymous_tuple.AnonymousTuple(
                     (k, v)
                     for (k, _), v in zip(element_kv_pairs, gathered))))
         cached_value = CachedValue(identifier, None, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     try:
         target_value = await cached_value.target_future
     except Exception as e:
          # TODO(b/145514490): This is a bit heavy-handed; there may be caches
          # where only the current cache item needs to be invalidated. However,
          # this currently only occurs when an inner RemoteExecutor has its
          # backend go down.
         self._cache = {}
         raise e
     type_utils.check_assignable_from(type_spec,
                                      target_value.type_signature)
     return cached_value
Example #21
 async def create_call(self, comp, arg=None):
     py_typecheck.check_type(comp, FederatedExecutorValue)
     if arg is not None:
         py_typecheck.check_type(arg, FederatedExecutorValue)
         py_typecheck.check_type(comp.type_signature,
                                 computation_types.FunctionType)
         param_type = comp.type_signature.parameter
         type_utils.check_assignable_from(param_type, arg.type_signature)
         arg = FederatedExecutorValue(arg.internal_representation,
                                      param_type)
     if isinstance(comp.internal_representation, pb.Computation):
         which_computation = comp.internal_representation.WhichOneof(
             'computation')
         if which_computation == 'tensorflow':
             child = self._target_executors[None][0]
             embedded_comp = await child.create_value(
                 comp.internal_representation, comp.type_signature)
             if arg is not None:
                 embedded_arg = await self._delegate(
                     child, arg.internal_representation, arg.type_signature)
             else:
                 embedded_arg = None
             result = await child.create_call(embedded_comp, embedded_arg)
             return FederatedExecutorValue(result, result.type_signature)
         else:
             raise ValueError(
                 'Directly calling computations of type {} is unsupported.'.
                 format(which_computation))
     if isinstance(comp.internal_representation,
                   intrinsic_defs.IntrinsicDef):
         coro = getattr(
             self, '_compute_intrinsic_{}'.format(
                 comp.internal_representation.uri))
         if coro is not None:
             return await coro(arg)
         else:
             raise NotImplementedError(
                 'Support for intrinsic "{}" has not been implemented yet.'.
                 format(comp.internal_representation.uri))
     else:
         raise ValueError(
             'Calling objects of type {} is unsupported.'.format(
                 py_typecheck.type_string(type(
                     comp.internal_representation))))
Example #22
 def _federated_reduce(self, arg):
   py_typecheck.check_type(arg.type_signature,
                           computation_types.NamedTupleType)
   federated_type = arg.type_signature[0]
   type_utils.check_federated_type(federated_type, None, placements.CLIENTS,
                                   False)
   zero_type = arg.type_signature[1]
   op_type = arg.type_signature[2]
   py_typecheck.check_type(op_type, computation_types.FunctionType)
   type_utils.check_assignable_from(op_type.parameter,
                                    [zero_type, federated_type.member])
   total = ComputedValue(arg.value[1], zero_type)
   reduce_fn = arg.value[2]
   for v in arg.value[0]:
     total = reduce_fn(
         ComputedValue(
             anonymous_tuple.AnonymousTuple([(None, total.value), (None, v)]),
             op_type.parameter))
   return self._federated_value_at_server(total)
Example #23
 async def create_selection(self, source, index=None, name=None):
     py_typecheck.check_type(source, CachedValue)
     py_typecheck.check_type(source.type_signature,
                             computation_types.NamedTupleType)
     source_val = await source.target_future
     if index is not None:
         py_typecheck.check_none(name)
         identifier_str = '{}[{}]'.format(source.identifier, index)
         type_spec = source.type_signature[index]
     else:
         py_typecheck.check_not_none(name)
         identifier_str = '{}.{}'.format(source.identifier, name)
         type_spec = getattr(source.type_signature, name)
     identifier = CachedValueIdentifier(identifier_str)
     try:
         cached_value = self._cache[identifier]
     except KeyError:
         target_future = asyncio.ensure_future(
             self._target_executor.create_selection(source_val,
                                                    index=index,
                                                    name=name))
         cached_value = CachedValue(identifier, None, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     try:
         target_value = await cached_value.target_future
     except Exception as e:
          # TODO(b/145514490): This is a bit heavy-handed; there may be caches
          # where only the current cache item needs to be invalidated. However,
          # this currently only occurs when an inner RemoteExecutor has its
          # backend go down.
         self._cache = {}
         raise e
     type_utils.check_assignable_from(type_spec,
                                      target_value.type_signature)
     return cached_value
Example #24
 async def create_call(self, comp, arg=None):
     py_typecheck.check_type(comp, CachedValue)
     py_typecheck.check_type(comp.type_signature,
                             computation_types.FunctionType)
     to_gather = [comp.target_future]
     if arg is not None:
         py_typecheck.check_type(arg, CachedValue)
         type_utils.check_assignable_from(comp.type_signature.parameter,
                                          arg.type_signature)
         to_gather.append(arg.target_future)
         identifier_str = '{}({})'.format(comp.identifier, arg.identifier)
     else:
         identifier_str = '{}()'.format(comp.identifier)
     gathered = await asyncio.gather(*to_gather)
     type_spec = comp.type_signature.result
     identifier = CachedValueIdentifier(identifier_str)
     try:
         cached_value = self._cache[identifier]
     except KeyError:
         target_future = asyncio.ensure_future(
             self._target_executor.create_call(*gathered))
         cached_value = CachedValue(identifier, None, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     try:
         target_value = await cached_value.target_future
     except Exception as e:
          # TODO(b/145514490): This is a bit heavy-handed; there may be caches
          # where only the current cache item needs to be invalidated. However,
          # this currently only occurs when an inner RemoteExecutor has its
          # backend go down.
         self._cache = {}
         raise e
     type_utils.check_assignable_from(type_spec,
                                      target_value.type_signature)
     return cached_value
Example #25
 async def create_tuple(self, elements):
     if not isinstance(elements, anonymous_tuple.AnonymousTuple):
         elements = anonymous_tuple.from_container(elements)
     element_strings = []
     element_kv_pairs = anonymous_tuple.to_elements(elements)
     to_gather = []
     type_elements = []
     for k, v in element_kv_pairs:
         py_typecheck.check_type(v, CachedValue)
         to_gather.append(v.target_future)
         if k is not None:
             py_typecheck.check_type(k, str)
             element_strings.append('{}={}'.format(k, v.identifier))
             type_elements.append((k, v.type_signature))
         else:
             element_strings.append(str(v.identifier))
             type_elements.append(v.type_signature)
     type_spec = computation_types.NamedTupleType(type_elements)
     gathered = await asyncio.gather(*to_gather)
     identifier = CachedValueIdentifier('<{}>'.format(
         ','.join(element_strings)))
     try:
         cached_value = self._cache[identifier]
     except KeyError:
         target_future = asyncio.ensure_future(
             self._target_executor.create_tuple(
                 anonymous_tuple.AnonymousTuple([
                     (k, v) for (k, _), v in zip(element_kv_pairs, gathered)
                 ])))
         cached_value = CachedValue(identifier, None, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     target_value = await cached_value.target_future
     type_utils.check_assignable_from(type_spec,
                                      target_value.type_signature)
     return cached_value
Example #26
  def __init__(self, initialize, prepare, work, zero, accumulate, merge, report,
               bitwidth, update):
    """Constructs a representation of a MapReduce-like iterative process.

    Note: All the computations supplied here as arguments must be TensorFlow
    computations, i.e., instances of `tff.Computation` constructed by the
    `tff.tf_computation` decorator/wrapper.

    Args:
      initialize: The computation that produces the initial server state.
      prepare: The computation that prepares the input for the clients.
      work: The client-side work computation.
      zero: The computation that produces the initial state for accumulators.
      accumulate: The computation that adds a client update to an accumulator.
      merge: The computation to use for merging pairs of accumulators.
      report: The computation that produces the final server-side aggregate for
        the top level accumulator (the global update).
      bitwidth: The computation that produces the bitwidth for secure sum.
      update: The computation that takes the global update and the server state
        and produces the new server state, as well as server-side output.

    Raises:
      TypeError: If the Python or TFF types of the arguments are invalid or not
        compatible with each other.
      AssertionError: If the manner in which the given TensorFlow computations
        are represented by TFF does not match what this code is expecting (this
        is an internal error that requires code update).
    """
    for label, comp in [
        ('initialize', initialize),
        ('prepare', prepare),
        ('work', work),
        ('zero', zero),
        ('accumulate', accumulate),
        ('merge', merge),
        ('report', report),
        ('bitwidth', bitwidth),
        ('update', update),
    ]:
      py_typecheck.check_type(comp, computation_base.Computation, label)

      # TODO(b/130633916): Remove private access once an appropriate API for it
      # becomes available.
      comp_proto = comp._computation_proto  # pylint: disable=protected-access

      if not isinstance(comp_proto, computation_pb2.Computation):
        # Explicitly raised to force it to be done in non-debug mode as well.
        raise AssertionError('Cannot find the embedded computation definition.')
      which_comp = comp_proto.WhichOneof('computation')
      if which_comp != 'tensorflow':
        raise TypeError('Expected all computations supplied as arguments to '
                        'be plain TensorFlow, found {}.'.format(which_comp))

    if prepare.type_signature.parameter != initialize.type_signature.result:
      raise TypeError(
          'The `prepare` computation expects an argument of type {}, '
          'which does not match the result type {} of `initialize`.'.format(
              prepare.type_signature.parameter,
              initialize.type_signature.result))

    if (not isinstance(work.type_signature.parameter,
                       computation_types.NamedTupleType) or
        len(work.type_signature.parameter) != 2):
      raise TypeError(
          'The `work` computation expects an argument of type {} that is not '
          'a two-tuple.'.format(work.type_signature.parameter))

    if work.type_signature.parameter[1] != prepare.type_signature.result:
      raise TypeError(
          'The `work` computation expects an argument tuple with type {} as '
          'the second element (the initial client state from the server), '
          'which does not match the result type {} of `prepare`.'.format(
              work.type_signature.parameter[1], prepare.type_signature.result))

    if (not isinstance(work.type_signature.result,
                       computation_types.NamedTupleType) or
        len(work.type_signature.result) != 2):
      raise TypeError(
          'The `work` computation returns a result of type {} that is not a '
          'two-tuple.'.format(work.type_signature.result))

    py_typecheck.check_type(zero.type_signature, computation_types.FunctionType)

    py_typecheck.check_type(accumulate.type_signature,
                            computation_types.FunctionType)
    py_typecheck.check_len(accumulate.type_signature.parameter, 2)
    type_utils.check_assignable_from(accumulate.type_signature.parameter[0],
                                     zero.type_signature.result)
    if (accumulate.type_signature.parameter[1] !=
        work.type_signature.result[0][0]):
      raise TypeError(
          'The `accumulate` computation expects a second argument of type {}, '
          'which does not match the expected {} as implied by the type '
          'signature of `work`.'.format(accumulate.type_signature.parameter[1],
                                        work.type_signature.result[0][0]))
    type_utils.check_assignable_from(accumulate.type_signature.parameter[0],
                                     accumulate.type_signature.result)

    py_typecheck.check_type(merge.type_signature,
                            computation_types.FunctionType)
    py_typecheck.check_len(merge.type_signature.parameter, 2)
    type_utils.check_assignable_from(merge.type_signature.parameter[0],
                                     accumulate.type_signature.result)
    type_utils.check_assignable_from(merge.type_signature.parameter[1],
                                     accumulate.type_signature.result)
    type_utils.check_assignable_from(merge.type_signature.parameter[0],
                                     merge.type_signature.result)

    py_typecheck.check_type(report.type_signature,
                            computation_types.FunctionType)
    type_utils.check_assignable_from(report.type_signature.parameter,
                                     merge.type_signature.result)

    py_typecheck.check_type(bitwidth.type_signature,
                            computation_types.FunctionType)

    expected_update_parameter_type = computation_types.to_type([
        initialize.type_signature.result,
        [report.type_signature.result, work.type_signature.result[0][1]],
    ])
    if update.type_signature.parameter != expected_update_parameter_type:
      raise TypeError(
          'The `update` computation expects an argument of type {}, '
          'which does not match the expected {} as implied by the type '
          'signatures of `initialize`, `report`, and `work`.'.format(
              update.type_signature.parameter, expected_update_parameter_type))

    if (not isinstance(update.type_signature.result,
                       computation_types.NamedTupleType) or
        len(update.type_signature.result) != 2):
      raise TypeError(
          'The `update` computation returns a result of type {} that is not '
          'a two-tuple.'.format(update.type_signature.result))

    if update.type_signature.result[0] != initialize.type_signature.result:
      raise TypeError(
          'The `update` computation returns a result tuple with type {} as '
          'the first element (the updated state of the server), which does '
          'not match the result type {} of `initialize`.'.format(
              update.type_signature.result[0],
              initialize.type_signature.result))

    self._initialize = initialize
    self._prepare = prepare
    self._work = work
    self._zero = zero
    self._accumulate = accumulate
    self._merge = merge
    self._report = report
    self._bitwidth = bitwidth
    self._update = update
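
The type constraints checked by this constructor encode how the pieces compose into a single round: prepare consumes the server state, work runs per client on (client data, prepared input), client results are folded with zero/accumulate, per-shard accumulators are combined with merge and summarized by report, and update maps (old state, global update) to the new state plus server output. A plain-Python sketch of one such round under those assumptions (the round driver itself is hypothetical, and the secure-sum bitwidth path is omitted):

# Plain-Python sketch of one MapReduce-like round; only the composition of the
# supplied computations is illustrated, with no TFF placements or types.
import functools

def run_one_round(server_state, client_data, *, prepare, work, zero,
                  accumulate, merge, report, update, num_shards=2):
  prepared = prepare(server_state)
  # Each client runs `work`; keep only the aggregatable half of its two-tuple.
  client_updates = [work(data, prepared)[0] for data in client_data]
  # Fold updates into per-shard accumulators, then merge the shards.
  shards = [client_updates[i::num_shards] for i in range(num_shards)]
  accumulators = [
      functools.reduce(accumulate, shard, zero()) for shard in shards]
  merged = functools.reduce(merge, accumulators[1:], accumulators[0])
  global_update = report(merged)
  # `update` maps (old state, global update) to (new state, server output).
  new_state, server_output = update(server_state, global_update)
  return new_state, server_output

# Tiny demo: counting the total number of examples across clients.
print(run_one_round(
    0, [[1, 2], [3], [4, 5, 6]],
    prepare=lambda state: state,
    work=lambda data, prepared: (len(data), ()),
    zero=lambda: 0,
    accumulate=lambda acc, upd: acc + upd,
    merge=lambda a, b: a + b,
    report=lambda acc: acc,
    update=lambda state, global_update: (state + global_update, global_update)))
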
Example #27
  async def create_call(self, comp, arg=None):
    """A coroutine that creates a call to `comp` with optional argument `arg`.

    Args:
      comp: The computation to invoke.
      arg: An optional argument of the call, or `None` if no argument was
        supplied.

    Returns:
      An instance of `FederatingExecutorValue` that represents the constructed
      call.

    Raises:
      TypeError: If `comp` or `arg` was not embedded in the executor before
        calling `create_call`, or if the `type_signature` of `arg` does not
        match the expected `type_signature` of the parameter to `comp`.
      ValueError: If `comp` is not a functional kind recognized by
        `FederatingExecutor` or if `comp` is a lambda with an argument.
      NotImplementedError: If `comp` is an intrinsic and it has not been
        implemented by the `FederatingExecutor`.
    """
    py_typecheck.check_type(comp, FederatingExecutorValue)
    if arg is not None:
      py_typecheck.check_type(arg, FederatingExecutorValue)
      py_typecheck.check_type(comp.type_signature,
                              computation_types.FunctionType)
      param_type = comp.type_signature.parameter
      type_utils.check_assignable_from(param_type, arg.type_signature)
      arg = FederatingExecutorValue(arg.internal_representation, param_type)
    if isinstance(comp.internal_representation, pb.Computation):
      which_computation = comp.internal_representation.WhichOneof('computation')
      if which_computation == 'lambda':
        if comp.type_signature.parameter is not None:
          raise ValueError(
              'Directly calling lambdas with arguments is unsupported. '
              'Found call to lambda with type {}.'.format(comp.type_signature))
        # Pull the inner computation out of called no-arg lambdas.
        return await self.create_value(
            getattr(comp.internal_representation, 'lambda').result,
            comp.type_signature.result)
      elif which_computation == 'tensorflow':
        # Run tensorflow computations.
        child = self._target_executors[None][0]
        embedded_comp = await child.create_value(comp.internal_representation,
                                                 comp.type_signature)
        if arg is not None:
          embedded_arg = await executor_utils.delegate_entirely_to_executor(
              arg.internal_representation, arg.type_signature, child)
        else:
          embedded_arg = None
        result = await child.create_call(embedded_comp, embedded_arg)
        return FederatingExecutorValue(result, result.type_signature)
      else:
        raise ValueError(
            'Directly calling computations of type {} is unsupported.'.format(
                which_computation))
    elif isinstance(comp.internal_representation, intrinsic_defs.IntrinsicDef):
      coro = getattr(
          self,
          '_compute_intrinsic_{}'.format(comp.internal_representation.uri),
          None)
      if coro is not None:
        return await coro(arg)  # pylint: disable=not-callable
      else:
        raise NotImplementedError(
            'Support for intrinsic "{}" has not been implemented yet.'.format(
                comp.internal_representation.uri))
    else:
      raise ValueError('Calling objects of type {} is unsupported.'.format(
          py_typecheck.type_string(type(comp.internal_representation))))
Example #28
 def _handle_callable(fn, fn_type):
   py_typecheck.check_type(fn, computation_base.Computation)
   type_utils.check_assignable_from(fn.type_signature, fn_type)
   computed_fn = self._compute(self._compile(fn), root_context)
   return computed_fn.value
Example #29
def to_representation_for_type(value, type_spec=None, device=None):
    """Verifies or converts the `value` to an eager object matching `type_spec`.

  WARNING: This function is only partially implemented. It does not support
  data sets at this point.

  The output of this function is always an eager tensor, eager dataset, a
  representation of a TensorFlow computation, or a nested structure of those
  that matches `type_spec`, and when `device` has been specified, everything
  is placed on that device on a best-effort basis.

  TensorFlow computations are represented here as zero- or one-argument Python
  callables that accept their entire argument bundle as a single Python object.

  Args:
    value: The raw representation of a value to compare against `type_spec` and
      potentially to be converted.
    type_spec: An instance of `tff.Type`, can be `None` for values that derive
      from `typed_object.TypedObject`.
    device: The optional device to place the value on (for tensor-level values).

  Returns:
    Either `value` itself, or a modified version of it.

  Raises:
    TypeError: If the `value` is not compatible with `type_spec`.
  """
    type_spec = type_utils.reconcile_value_with_type_spec(value, type_spec)
    if isinstance(value, computation_base.Computation):
        return to_representation_for_type(
            computation_impl.ComputationImpl.get_proto(value), type_spec,
            device)
    elif isinstance(value, pb.Computation):
        return embed_tensorflow_computation(value, type_spec, device)
    elif isinstance(type_spec, computation_types.NamedTupleType):
        type_elem = anonymous_tuple.to_elements(type_spec)
        value_elem = (anonymous_tuple.to_elements(
            anonymous_tuple.from_container(value)))
        result_elem = []
        if len(type_elem) != len(value_elem):
            raise TypeError(
                'Expected a {}-element tuple, found {} elements.'.format(
                    len(type_elem), len(value_elem)))
        for (t_name, el_type), (v_name, el_val) in zip(type_elem, value_elem):
            if t_name != v_name:
                raise TypeError(
                    'Mismatching element names in type vs. value: {} vs. {}.'.
                    format(t_name, v_name))
            el_repr = to_representation_for_type(el_val, el_type, device)
            result_elem.append((t_name, el_repr))
        return anonymous_tuple.AnonymousTuple(result_elem)
    elif device is not None:
        py_typecheck.check_type(device, str)
        with tf.device(device):
            return to_representation_for_type(value,
                                              type_spec=type_spec,
                                              device=None)
    elif isinstance(value, EagerValue):
        return value.internal_representation
    elif isinstance(value, executor_value_base.ExecutorValue):
        raise TypeError(
            'Cannot accept a value embedded within a non-eager executor.')
    elif isinstance(type_spec, computation_types.TensorType):
        if not tf.is_tensor(value):
            value = tf.convert_to_tensor(value, dtype=type_spec.dtype)
        elif hasattr(value, 'read_value'):
            # a tf.Variable-like result, get a proper tensor.
            value = value.read_value()
        value_type = (computation_types.TensorType(value.dtype.base_dtype,
                                                   value.shape))
        if not type_utils.is_assignable_from(type_spec, value_type):
            raise TypeError(
                'The apparent type {} of a tensor {} does not match the expected '
                'type {}.'.format(value_type, value, type_spec))
        return value
    elif isinstance(type_spec, computation_types.SequenceType):
        if isinstance(value, list):
            value = tensorflow_utils.make_data_set_from_elements(
                None, value, type_spec.element)
        py_typecheck.check_type(value,
                                type_utils.TF_DATASET_REPRESENTATION_TYPES)
        element_type = computation_types.to_type(
            tf.data.experimental.get_structure(value))
        value_type = computation_types.SequenceType(element_type)
        type_utils.check_assignable_from(type_spec, value_type)
        return value
    else:
        raise TypeError('Unexpected type {}.'.format(type_spec))
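
A hedged usage sketch for the function above, converting a plain Python list into its eager tensor representation. The public `tff.TensorType` alias is assumed to be available, and `to_representation_for_type` refers to the definition above.

# Usage sketch; assumptions: the public `tff.TensorType` alias is importable
# and `to_representation_for_type` above is in scope.
import tensorflow as tf
import tensorflow_federated as tff

eager_repr = to_representation_for_type(
    [1, 2, 3], type_spec=tff.TensorType(tf.int32, [3]))
# Expected result: a tf.Tensor of dtype int32 and shape [3]; an incompatible
# `type_spec` raises TypeError instead.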
Example No. 30
def serialize_value(value, type_spec=None):
    """Serializes a value into `executor_pb2.Value`.

  Args:
    value: A value to be serialized.
    type_spec: Optional type spec, a `tff.Type` or something convertible to it.

  Returns:
    A tuple `(value_proto, ret_type_spec)` where `value_proto` is an instance
    of `executor_pb2.Value` with the serialized content of `value`, and the
    returned `ret_type_spec` is an instance of `tff.Type` that represents the
    TFF type of the serialized value.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the value is malformed.
  """
    type_spec = computation_types.to_type(type_spec)
    if isinstance(value, computation_pb2.Computation):
        type_spec = type_utils.reconcile_value_type_with_type_spec(
            type_serialization.deserialize_type(value.type), type_spec)
        return executor_pb2.Value(computation=value), type_spec
    elif isinstance(value, computation_impl.ComputationImpl):
        return serialize_value(
            computation_impl.ComputationImpl.get_proto(value),
            type_utils.reconcile_value_with_type_spec(value, type_spec))
    elif isinstance(type_spec, computation_types.TensorType):
        return serialize_tensor_value(value, type_spec)
    elif isinstance(type_spec, computation_types.NamedTupleType):
        type_elements = anonymous_tuple.to_elements(type_spec)
        val_elements = anonymous_tuple.to_elements(
            anonymous_tuple.from_container(value))
        tup_elems = []
        for (e_name, e_type), (_, e_val) in zip(type_elements, val_elements):
            e_proto, _ = serialize_value(e_val, e_type)
            tup_elems.append(
                executor_pb2.Value.Tuple.Element(
                    name=e_name if e_name else None, value=e_proto))
        result_proto = (executor_pb2.Value(tuple=executor_pb2.Value.Tuple(
            element=tup_elems)))
        return result_proto, type_spec
    elif isinstance(type_spec, computation_types.SequenceType):
        if not isinstance(value,
                          type_conversions.TF_DATASET_REPRESENTATION_TYPES):
            raise TypeError(
                'Cannot serialize Python type {!s} as TFF type {!s}.'.format(
                    py_typecheck.type_string(type(value)),
                    type_spec if type_spec is not None else 'unknown'))

        value_type = computation_types.SequenceType(
            computation_types.to_type(value.element_spec))
        if not type_utils.is_assignable_from(type_spec, value_type):
            raise TypeError(
                'Cannot serialize dataset with elements of type {!s} as TFF type {!s}.'
                .format(value_type,
                        type_spec if type_spec is not None else 'unknown'))

        return serialize_sequence_value(value), type_spec
    elif isinstance(type_spec, computation_types.FederatedType):
        if type_spec.all_equal:
            value = [value]
        else:
            py_typecheck.check_type(value, list)
        items = []
        for v in value:
            it, it_type = serialize_value(v, type_spec.member)
            type_utils.check_assignable_from(type_spec.member, it_type)
            items.append(it)
        result_proto = executor_pb2.Value(
            federated=executor_pb2.Value.Federated(
                type=type_serialization.serialize_type(type_spec).federated,
                value=items))
        return result_proto, type_spec
    else:
        raise ValueError(
            'Unable to serialize value with Python type {} and {} TFF type.'.
            format(str(py_typecheck.type_string(type(value))),
                   str(type_spec) if type_spec is not None else 'unknown'))
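
A hedged round-trip sketch for `serialize_value`, serializing a scalar tensor value. The public `tff.TensorType` alias is assumed, and the call reuses the function defined above (which in turn relies on `serialize_tensor_value` from the same module).

# Round-trip sketch; assumptions: the public `tff.TensorType` alias is
# importable and `serialize_value` above (plus `serialize_tensor_value` from
# the same module) is in scope.
import tensorflow as tf
import tensorflow_federated as tff

value_proto, value_type = serialize_value(10, tff.TensorType(tf.int32))
# `value_proto` is an `executor_pb2.Value` carrying the serialized tensor and
# `value_type` is the reconciled `tff.Type` of the value (here, `int32`).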