def create_lambda_identity(
        type_spec: computation_types.Type) -> pb.Computation:
    """Builds a lambda computation that hands its argument straight back.

    Has the type signature:

    (T -> T)

    Args:
      type_spec: A `computation_types.Type`, used both as the type of the
        lambda's result and as the input to `type_factory.unary_op`.

    Returns:
      A `pb.Computation` whose `lambda` field binds the parameter `a` and
      whose result is a reference to `a`.
    """
    fn_type = type_factory.unary_op(type_spec)
    param_name = 'a'
    body = pb.Computation(
        type=type_serialization.serialize_type(type_spec),
        reference=pb.Reference(name=param_name))
    # `lambda` is a reserved keyword in Python, yet it is also the name of a
    # `pb.Computation` field, so it cannot be spelled as an ordinary keyword
    # argument and is passed through dict unpacking instead.
    # https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts
    lambda_field = {
        'lambda': pb.Lambda(parameter_name=param_name, result=body),
    }
    return pb.Computation(
        type=type_serialization.serialize_type(fn_type),
        **lambda_field)  # pytype: disable=wrong-keyword-args
def create_dummy_computation_reference():
  """Builds a float32 reference computation named `a`.

  Returns:
    A `(value, type_signature)` tuple, where `value` is a `pb.Computation`
    holding a `pb.Reference` named `a` and `type_signature` is the
    `computation_types.TensorType(tf.float32)` that was serialized into it.
  """
  ref_type = computation_types.TensorType(tf.float32)
  serialized_type = type_serialization.serialize_type(ref_type)
  ref_proto = pb.Computation(
      type=serialized_type, reference=pb.Reference(name='a'))
  return ref_proto, ref_type
Esempio n. 3
0
 def test_executor_create_value_with_unbound_reference(self):
     """Expects `ValueError` when creating a value from a bare reference."""
     event_loop = asyncio.get_event_loop()
     executor = _make_test_executor()
     with self.assertRaises(ValueError):
         # The proto is a lone reference to `a`; nothing in it binds `a`.
         unbound_ref = pb.Computation(
             reference=pb.Reference(name='a'),
             type=type_serialization.serialize_type(tf.int32))
         event_loop.run_until_complete(executor.create_value(unbound_ref))
Esempio n. 4
0
def _create_lambda_identity_comp(type_spec):
    """Builds a `pb.Computation` for an identity lambda over `type_spec`.

    Args:
      type_spec: Checked with `py_typecheck.check_type` to be a
        `computation_types.Type`.

    Returns:
      A `pb.Computation` whose `lambda` field binds the parameter `x` and
      whose result is a reference to `x`.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    fn_type = type_factory.unary_op(type_spec)
    serialized_fn_type = type_serialization.serialize_type(fn_type)
    param_name = 'x'
    body = pb.Computation(
        type=type_serialization.serialize_type(type_spec),
        reference=pb.Reference(name=param_name))
    # `lambda` is a reserved keyword in Python, yet it is also the name of a
    # `pb.Computation` field, so it cannot be spelled as an ordinary keyword
    # argument and is passed through dict unpacking instead.
    # https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts
    lambda_field = {
        'lambda': pb.Lambda(parameter_name=param_name, result=body),
    }
    return pb.Computation(type=serialized_fn_type, **lambda_field)  # pytype: disable=wrong-keyword-args
Esempio n. 5
0
def _create_lambda_identity_comp(type_spec):
    """Builds a `pb.Computation` for an identity lambda over `type_spec`.

    Args:
      type_spec: Checked with `py_typecheck.check_type` to be a
        `computation_types.Type`.

    Returns:
      A `pb.Computation` whose `lambda` field binds the parameter `x` and
      whose result is a reference to `x`.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    fn_type_proto = type_serialization.serialize_type(
        type_factory.unary_op(type_spec))
    param_ref = pb.Computation(
        type=type_serialization.serialize_type(type_spec),
        reference=pb.Reference(name='x'))
    identity_fn = pb.Lambda(parameter_name='x', result=param_ref)
    # Both fields go through one unpacked dict because `lambda` is a Python
    # keyword and cannot be written as a plain keyword argument.
    return pb.Computation(**{'type': fn_type_proto, 'lambda': identity_fn})
Esempio n. 6
0
def create_dummy_computation_lambda_identity():
  """Builds an identity lambda computation of type `(float32 -> float32)`.

  Returns:
    A `(value, type_signature)` tuple, where `value` is a `pb.Computation`
    whose `lambda` field binds the parameter `a` and returns a reference to
    `a`, and `type_signature` is the result of
    `type_factory.unary_op(tf.float32)`.
  """
  fn_type = type_factory.unary_op(tf.float32)
  param_name = 'a'
  body = pb.Computation(
      type=type_serialization.serialize_type(tf.float32),
      reference=pb.Reference(name=param_name))
  # `lambda` is a reserved keyword in Python, yet it is also the name of a
  # `pb.Computation` field, so it cannot be spelled as an ordinary keyword
  # argument and is passed through dict unpacking instead.
  # https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts
  lambda_field = {'lambda': pb.Lambda(parameter_name=param_name, result=body)}
  fn_proto = pb.Computation(
      type=type_serialization.serialize_type(fn_type), **lambda_field)  # pytype: disable=wrong-keyword-args
  return fn_proto, fn_type
Esempio n. 7
0
    def test_invoke_raises_value_error_with_federated_computation(self):
        """Expects `invoke` to reject a computation backed by a reference proto."""
        fn_type = computation_types.to_type(
            computation_types.FunctionType(tf.int32, tf.int32))
        # A proto that carries a reference rather than TensorFlow content.
        bogus_proto = pb.Computation(
            type=type_serialization.serialize_type(fn_type),
            reference=pb.Reference(name='boogledy'))
        non_tf_computation = computation_impl.ComputationImpl(
            bogus_proto, context_stack_impl.context_stack)

        default_graph = tf.compat.v1.get_default_graph()
        context = tensorflow_computation_context.TensorFlowComputationContext(
            default_graph)

        expected_message = ('Can only invoke TensorFlow in the body of '
                            'a TensorFlow computation')
        with self.assertRaisesRegex(ValueError, expected_message):
            context.invoke(non_tf_computation, None)
Esempio n. 8
0
def create_dummy_identity_lambda_computation(type_spec=tf.int32):
    """Builds a `pb.Computation` for an identity lambda over `type_spec`.

    With the default `type_spec`, the type signature of the returned
    `pb.Computation` is:

    (int32 -> int32)

    Args:
      type_spec: A type signature; defaults to `tf.int32`.

    Returns:
      A `pb.Computation` whose `lambda` field binds the parameter `a` and
      whose result is a reference to `a`.
    """
    fn_type = type_factory.unary_op(type_spec)
    serialized_fn_type = type_serialization.serialize_type(fn_type)
    param_name = 'a'
    body = pb.Computation(
        type=type_serialization.serialize_type(type_spec),
        reference=pb.Reference(name=param_name))
    # `lambda` is a reserved keyword in Python, yet it is also the name of a
    # `pb.Computation` field, so it cannot be spelled as an ordinary keyword
    # argument and is passed through dict unpacking instead.
    # https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts
    lambda_field = {
        'lambda': pb.Lambda(parameter_name=param_name, result=body),
    }
    return pb.Computation(type=serialized_fn_type, **lambda_field)  # pytype: disable=wrong-keyword-args
Esempio n. 9
0
 def proto(self):
     """Returns a reference `pb.Computation` built from this object.

     The proto carries the serialized `self.type_signature` as its type and
     a `pb.Reference` named after `self._name`.
     """
     serialized_type = type_serialization.serialize_type(self.type_signature)
     reference = pb.Reference(name=self._name)
     return pb.Computation(type=serialized_type, reference=reference)
Esempio n. 10
0
 def test_executor_create_value_with_unbound_reference(self):
     """Expects `ValueError` when producing a value from a bare reference."""
     with self.assertRaises(ValueError):
         # The proto is a lone reference to `a`; nothing in it binds `a`.
         unbound_ref = pb.Computation(
             reference=pb.Reference(name='a'),
             type=type_serialization.serialize_type(tf.int32))
         _produce_test_value(unbound_ref)
def create_dummy_computation_reference():
    """Builds an int32 reference computation named `a`.

    Returns:
      A `(value, type_signature)` tuple, where `value` is a `pb.Computation`
      holding a `pb.Reference` named `a` with the serialized `tf.int32` type,
      and `type_signature` is `computation_types.TensorType(tf.int32)`.
    """
    serialized_type = type_serialization.serialize_type(tf.int32)
    ref_proto = pb.Computation(
        type=serialized_type, reference=pb.Reference(name='a'))
    ref_type = computation_types.TensorType(tf.int32)
    return ref_proto, ref_type