def test_raises_type_error_with_value_and_bad_type(self):
  """Constructing with an embedded value but a `None` type signature fails.

  The test expects `FederatingExecutorValue(...)` itself (not `compute()`)
  to raise `TypeError` when the type signature is `None`.
  """
  embedded = eager_tf_executor.EagerValue(10.0, None, tf.float32)
  # Callable form of assertRaises: the constructor is invoked inside the
  # assertion with (embedded, None) as its positional arguments.
  self.assertRaises(TypeError, federating_executor.FederatingExecutorValue,
                    embedded, None)
Esempio n. 2
0
  def test_raises_runtime_error_with_unsupported_value_or_type(self):
    """Computing a raw Python float typed as a plain tensor fails.

    The payload is a bare `10.0` (not wrapped in an `EagerValue`), and the
    test expects `compute()` to raise `RuntimeError` for it.
    """
    executor_value = federating_executor.FederatingExecutorValue(
        10.0, computation_types.TensorType(tf.float32))

    # Construction succeeds; only the computation is expected to fail.
    with self.assertRaises(RuntimeError):
      self.run_sync(executor_value.compute())
Esempio n. 3
0
  def test_raises_type_error_with_unembedded_federated_type(self):
    """Computing raw Python floats typed as placed at CLIENTS fails.

    The members are plain floats rather than `EagerValue`s, and the test
    expects `compute()` to reject them with a `TypeError`.
    """
    raw_client_values = [10.0, 11.0, 12.0]
    executor_value = federating_executor.FederatingExecutorValue(
        raw_client_values, type_factory.at_clients(tf.float32))

    # Construction succeeds; only the computation is expected to fail.
    with self.assertRaises(TypeError):
      self.run_sync(executor_value.compute())
Esempio n. 4
0
  def test_returns_value_with_federated_type_at_server(self):
    """A SERVER-placed value holding one embedded member computes to 10.0.

    The federated value is a one-element list wrapping an `EagerValue` of
    `10.0`; the test expects `compute()` to yield that payload directly.
    """
    server_members = [eager_tf_executor.EagerValue(10.0, None, tf.float32)]
    executor_value = federating_executor.FederatingExecutorValue(
        server_members, type_factory.at_server(tf.float32))

    self.assertEqual(self.run_sync(executor_value.compute()), 10.0)
Esempio n. 5
0
  def test_returns_value_with_embedded_value(self):
    """An embedded `EagerValue` of 10.0 with a tensor type computes to 10.0."""
    embedded = eager_tf_executor.EagerValue(10.0, None, tf.float32)
    executor_value = federating_executor.FederatingExecutorValue(
        embedded, computation_types.TensorType(tf.float32))

    computed = self.run_sync(executor_value.compute())
    self.assertEqual(computed, 10.0)
Esempio n. 6
0
  def test_returns_value_with_anonymous_tuple_value(self):
    """A named tuple of embedded values computes to a tuple of payloads.

    Fields `a`, `b` and `c` each hold the same `EagerValue` of `10.0`; the
    test expects `compute()` to return an `AnonymousTuple` with `10.0` under
    every field name.
    """
    field_names = ['a', 'b', 'c']
    # One embedded value and one scalar type are shared across all fields.
    embedded = eager_tf_executor.EagerValue(10.0, None, tf.float32)
    scalar_type = computation_types.TensorType(tf.float32)

    tuple_value = anonymous_tuple.AnonymousTuple(
        [(name, embedded) for name in field_names])
    tuple_type = computation_types.NamedTupleType(
        [(name, scalar_type) for name in field_names])
    executor_value = federating_executor.FederatingExecutorValue(
        tuple_value, tuple_type)

    computed = self.run_sync(executor_value.compute())

    self.assertEqual(
        computed,
        anonymous_tuple.AnonymousTuple([(name, 10.0) for name in field_names]))