Example #1
0
 def assert_serializes(self, fn, parameter_type, expected_fn_type_str):
   """Serializes `fn` as a TF computation and checks the resulting type.

   Drives `tensorflow_serialization.tf_computation_serializer`: takes the
   argument it yields, calls `fn` on it, and sends `fn`'s result back.

   Args:
     fn: The Python function to serialize; called once with the argument
       produced by the serializer.
     parameter_type: The parameter type handed to the serializer.
     expected_fn_type_str: The expected `compact_representation()` of the
       deserialized computation type.

   Returns:
     A `(tensorflow, extra_type_spec)` tuple: the `tensorflow` field of the
     serialized computation, and the type spec returned by the serializer.
   """
   serializer = tensorflow_serialization.tf_computation_serializer(
       parameter_type, context_stack_impl.context_stack)
   try:
     arg_to_fn = next(serializer)
     result = fn(arg_to_fn)
     comp, extra_type_spec = serializer.send(result)
   finally:
     # Finalize the generator on every exit path, including when `fn` raises.
     # Otherwise it can stay suspended mid-body (the exception's traceback
     # keeps this frame, and so the generator, alive), presumably leaving
     # whatever it set up with the shared `context_stack` in place for later
     # tests. Closing an already-finished generator is a no-op.
     serializer.close()
   deserialized_type = type_serialization.deserialize_type(comp.type)
   type_test_utils.assert_types_equivalent(deserialized_type, extra_type_spec)
   self.assertEqual(deserialized_type.compact_representation(),
                    expected_fn_type_str)
   self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
   return comp.tensorflow, extra_type_spec
Example #2
0
def _tf_wrapper_fn(parameter_type, name):
  """Wrapper function to plug Tensorflow logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation of its arguments can be found inside the definition of that
  class.

  This is a two-step generator. It first yields the argument produced by
  `tensorflow_serialization.tf_computation_serializer`. It then expects the
  result of the wrapped function to be sent back in, and finally yields a
  `computation_impl.ComputationImpl` built from the serialized computation.

  Raises:
    TypeError: If `parameter_type` is not TensorFlow-compatible.
  """
  del name  # Unused.
  if not type_analysis.is_tensorflow_compatible_type(parameter_type):
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  ctx_stack = context_stack_impl.context_stack
  tf_serializer = tensorflow_serialization.tf_computation_serializer(
      parameter_type, ctx_stack)
  try:
    arg = next(tf_serializer)
    try:
      result = yield arg
    except Exception as e:  # pylint: disable=broad-except
      # Forward the failure into the serializer so it can unwind whatever it
      # set up before yielding `arg`; presumably it re-raises `e` from here.
      tf_serializer.throw(e)
      # Only reached if the serializer swallowed `e`: never continue without
      # a result.
      raise
    comp_pb, extra_type_spec = tf_serializer.send(result)
  finally:
    # Finalize the serializer on every exit path, including when this wrapper
    # is itself closed while suspended at `yield arg`. This is a no-op if the
    # serializer has already finished.
    tf_serializer.close()
  yield computation_impl.ComputationImpl(comp_pb, ctx_stack, extra_type_spec)
def _tf_wrapper_fn(parameter_type, name):
    """Wrapper function to plug Tensorflow logic into the TFF framework.

    This is a two-step generator. It first yields the argument produced by
    `tensorflow_serialization.tf_computation_serializer`. It then expects the
    result of the wrapped function to be sent back in, and finally yields a
    `computation_impl.ConcreteComputation` built from the serialized
    computation.

    Args:
      parameter_type: The parameter type of the computation; must be
        TensorFlow-compatible.
      name: Unused.

    Raises:
      TypeError: If `parameter_type` is not TensorFlow-compatible.
    """
    del name  # Unused.
    if not type_analysis.is_tensorflow_compatible_type(parameter_type):
        raise TypeError(
            '`tf_computation`s can accept only parameter types with '
            'constituents `SequenceType`, `StructType` '
            'and `TensorType`; you have attempted to create one '
            'with the type {}.'.format(parameter_type))
    ctx_stack = context_stack_impl.context_stack
    tf_serializer = tensorflow_serialization.tf_computation_serializer(
        parameter_type, ctx_stack)
    try:
        arg = next(tf_serializer)
        try:
            result = yield arg
        except Exception as e:  # pylint: disable=broad-except
            # Forward the failure into the serializer so it can unwind what it
            # set up before yielding `arg`; presumably it re-raises `e` here.
            tf_serializer.throw(e)
            # Only reached if the serializer swallowed `e`. Re-raise rather
            # than fall through to `send(result)` with `result` unbound, which
            # would mask the original error with an UnboundLocalError.
            raise
        comp_pb, extra_type_spec = tf_serializer.send(result)
    finally:
        # Finalize the serializer on every exit path (no-op if already done).
        tf_serializer.close()
    yield computation_impl.ConcreteComputation(comp_pb, ctx_stack,
                                               extra_type_spec)
Example #4
0
def _tf_computation_serializer(fn, parameter_type, context):
    """Serializes `fn` via `tensorflow_serialization.tf_computation_serializer`.

    Drives the serializer generator: takes the argument it yields, calls `fn`
    on it, and sends `fn`'s result back.

    Args:
      fn: Callable invoked once with the argument yielded by the serializer.
      parameter_type: The parameter type forwarded to the serializer.
      context: Forwarded to the serializer as its second argument
        (presumably a context stack -- TODO confirm against callers).

    Returns:
      The value the serializer yields in response to `fn`'s result
      (presumably a `(computation_proto, extra_type_spec)` pair).
    """
    serializer = tensorflow_serialization.tf_computation_serializer(
        parameter_type, context)
    try:
        arg_to_fn = next(serializer)
        result = fn(arg_to_fn)
        return serializer.send(result)
    finally:
        # Finalize the generator even when `fn` raises. Otherwise it stays
        # suspended mid-body (kept alive by the exception's traceback) with
        # whatever it set up on `context` still in place. No-op if finished.
        serializer.close()