def test_raises_type_error_with_value_and_bad_type(self):
    """Constructing with a `None` type signature must raise TypeError."""
    embedded = eager_tf_executor.EagerValue(10.0, None, tf.float32)
    missing_type = None
    with self.assertRaises(TypeError):
        federating_executor.FederatingExecutorValue(embedded, missing_type)
def test_raises_runtime_error_with_unsupported_value_or_type(self):
    """Computing a raw Python float paired with a tensor type raises RuntimeError.

    Construction happens outside the `assertRaises` context on purpose: only
    the `compute()` step is expected to fail.
    """
    tensor_type = computation_types.TensorType(tf.float32)
    executor_value = federating_executor.FederatingExecutorValue(
        10.0, tensor_type)
    with self.assertRaises(RuntimeError):
        self.run_sync(executor_value.compute())
def test_raises_type_error_with_unembedded_federated_type(self):
    """A CLIENTS-placed value holding raw Python floats fails in `compute()`.

    The values are plain floats rather than embedded `EagerValue`s; the
    expected failure is a TypeError raised while computing, not at
    construction time.
    """
    raw_client_values = [10.0, 11.0, 12.0]
    clients_type = type_factory.at_clients(tf.float32)
    executor_value = federating_executor.FederatingExecutorValue(
        raw_client_values, clients_type)
    with self.assertRaises(TypeError):
        self.run_sync(executor_value.compute())
def test_returns_value_with_federated_type_at_server(self):
    """A SERVER-placed single-element list of EagerValue computes to 10.0."""
    server_payload = [eager_tf_executor.EagerValue(10.0, None, tf.float32)]
    server_type = type_factory.at_server(tf.float32)
    executor_value = federating_executor.FederatingExecutorValue(
        server_payload, server_type)
    computed = self.run_sync(executor_value.compute())
    self.assertEqual(computed, 10.0)
def test_returns_value_with_embedded_value(self):
    """An EagerValue paired with a float32 tensor type computes to 10.0."""
    tensor_type = computation_types.TensorType(tf.float32)
    embedded = eager_tf_executor.EagerValue(10.0, None, tf.float32)
    executor_value = federating_executor.FederatingExecutorValue(
        embedded, tensor_type)
    computed = self.run_sync(executor_value.compute())
    self.assertEqual(computed, 10.0)
def test_returns_value_with_anonymous_tuple_value(self):
    """A named tuple of EagerValues computes to a named tuple of 10.0s.

    The same `EagerValue` instance is reused for every field; the result must
    keep the field names 'a', 'b', 'c' in order.
    """
    field_names = ('a', 'b', 'c')
    eager_element = eager_tf_executor.EagerValue(10.0, None, tf.float32)
    element_type = computation_types.TensorType(tf.float32)
    tuple_value = anonymous_tuple.AnonymousTuple(
        [(name, eager_element) for name in field_names])
    tuple_type = computation_types.NamedTupleType(
        [(name, element_type) for name in field_names])
    executor_value = federating_executor.FederatingExecutorValue(
        tuple_value, tuple_type)
    computed = self.run_sync(executor_value.compute())
    self.assertEqual(
        computed,
        anonymous_tuple.AnonymousTuple(
            [(name, 10.0) for name in field_names]))