def should_transform(self, comp):
  """Returns whether `comp` should be rewritten into a TensorFlow computation.

  A computation qualifies only when all of the following hold:
    * its type signature is TensorFlow-compatible, or it is a function whose
      parameter and result types are both TensorFlow-compatible;
    * it is not already a compiled computation, nor a call to one;
    * it has no unbound references;
    * it contains no `building_blocks.Intrinsic`.

  Args:
    comp: The building block under consideration.

  Returns:
    `True` if `comp` should be transformed, `False` otherwise.
  """
  signature = comp.type_signature
  is_tf_type = type_analysis.is_tensorflow_compatible_type
  has_tf_signature = is_tf_type(signature) or (
      signature.is_function() and
      is_tf_type(signature.parameter) and
      is_tf_type(signature.result))
  if not has_tf_signature:
    return False

  # Compiled computations (or calls to them) are already the final product of
  # TF generation, so there is nothing left to transform.
  already_compiled = comp.is_compiled_computation() or (
      comp.is_call() and comp.function.is_compiled_computation())
  if already_compiled:
    return False

  # Captured (unbound) references cannot be represented without further
  # information, so such computations are left alone.
  unbound = transformation_utils.get_map_of_unbound_references(comp)[comp]
  if unbound:
    return False

  return not tree_analysis.contains_types(comp, building_blocks.Intrinsic)
Example #2
0
 def is_allowed_client_data_type(
          type_spec: computation_types.Type) -> bool:
     """Returns whether `type_spec` is an acceptable client data type.

     A sequence type is accepted when its element type is TensorFlow-compatible.
     A struct type is accepted when every one of its children is itself
     accepted (checked recursively). Any other type is rejected.
     """
     if type_spec.is_sequence():
         element_type = type_spec.element
         return type_analysis.is_tensorflow_compatible_type(element_type)
     if not type_spec.is_struct():
         return False
     for child_type in type_spec.children():
         if not is_allowed_client_data_type(child_type):
             return False
     return True
Example #3
0
def _tf_wrapper_fn(parameter_type, name):
  """Wrapper function to plug TensorFlow logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation of its arguments can be found inside the definition of that
  class.

  This is a two-step generator:
    1. It yields the value produced by the TensorFlow serializer's first step.
       The caller is expected to `send()` back the traced result.
    2. It then yields a `computation_impl.ComputationImpl` built from the
       serialized computation.

  If the caller throws an exception into this generator while it is suspended,
  the exception is forwarded into the serializer and then propagated.

  Args:
    parameter_type: The TFF type of the computation's parameter.
    name: Unused.

  Raises:
    TypeError: If `parameter_type` is not TensorFlow-compatible.
  """
  del name  # Unused.
  if not type_analysis.is_tensorflow_compatible_type(parameter_type):
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  ctx_stack = context_stack_impl.context_stack
  tf_serializer = tensorflow_serialization.tf_computation_serializer(
      parameter_type, ctx_stack)
  # Close the serializer on every exit path, including when this generator is
  # itself closed by its caller while suspended. This lets the serializer run
  # its cleanup deterministically instead of waiting for garbage collection.
  try:
    arg = next(tf_serializer)
    try:
      result = yield arg
    except Exception as e:  # pylint: disable=broad-except
      # Raise the caller's failure at the serializer's suspension point so it
      # can unwind. Then re-raise here in case the serializer swallowed it.
      tf_serializer.throw(e)
      raise
    comp_pb, extra_type_spec = tf_serializer.send(result)
  finally:
    tf_serializer.close()
  yield computation_impl.ComputationImpl(comp_pb, ctx_stack, extra_type_spec)
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Wraps a Python function as a TFF TensorFlow computation.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation of its arguments can be found inside the definition of that
  class.

  `target_fn` is first adapted to a zero-or-one-argument callable. It is then
  serialized as a TensorFlow computation against the current context stack.

  Returns:
    A `computation_impl.ComputationImpl` wrapping the serialized computation.

  Raises:
    TypeError: If `parameter_type` is not TensorFlow-compatible.
  """
  del name  # Unused.
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  if not type_analysis.is_tensorflow_compatible_type(parameter_type):
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  stack = context_stack_impl.context_stack
  serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      wrapped_fn, parameter_type, stack)
  proto, type_spec = serialized
  return computation_impl.ComputationImpl(proto, stack, type_spec)
def _tf_wrapper_fn(parameter_type, name):
    """Wrapper function to plug TensorFlow logic into the TFF framework.

    This is a two-step generator:
      1. It yields the value produced by the TensorFlow serializer's first
         step. The caller is expected to `send()` back the traced result.
      2. It then yields a `computation_impl.ConcreteComputation` built from
         the serialized computation.

    If the caller throws an exception into this generator while it is
    suspended, the exception is forwarded into the serializer and then
    propagated.

    Args:
      parameter_type: The TFF type of the computation's parameter.
      name: Unused.

    Raises:
      TypeError: If `parameter_type` is not TensorFlow-compatible.
    """
    del name  # Unused.
    if not type_analysis.is_tensorflow_compatible_type(parameter_type):
        raise TypeError(
            '`tf_computation`s can accept only parameter types with '
            'constituents `SequenceType`, `StructType` '
            'and `TensorType`; you have attempted to create one '
            'with the type {}.'.format(parameter_type))
    ctx_stack = context_stack_impl.context_stack
    tf_serializer = tensorflow_serialization.tf_computation_serializer(
        parameter_type, ctx_stack)
    # Close the serializer on every exit path. That includes a GeneratorExit
    # raised at `yield arg` when the caller closes this generator, which the
    # `except Exception` below does not catch.
    try:
        arg = next(tf_serializer)
        try:
            result = yield arg
        except Exception as e:  # pylint: disable=broad-except
            # Raise the caller's failure at the serializer's suspension point
            # so it can unwind. Re-raise afterwards: if the serializer swallowed
            # `e`, falling through would hit `send(result)` with `result` unbound.
            tf_serializer.throw(e)
            raise
        comp_pb, extra_type_spec = tf_serializer.send(result)
    finally:
        tf_serializer.close()
    yield computation_impl.ConcreteComputation(comp_pb, ctx_stack,
                                               extra_type_spec)