def test_executor_service_create_one_arg_computation_value_and_call(self):
  """Calls a one-argument TF computation (x + 1) on 10 through the service."""
  factory = executor_test_utils.BasicTestExFactory(
      eager_tf_executor.EagerTFExecutor())
  env = TestEnv(factory)

  @tensorflow_computation.tf_computation(tf.int32)
  def comp(x):
    return tf.add(x, 1)

  def _create(serialized):
    # Registers one serialized value with the service and returns its ref.
    created = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(
            executor=env.executor_pb, value=serialized))
    self.assertIsInstance(created, executor_pb2.CreateValueResponse)
    return created.value_ref

  fn_proto, _ = value_serialization.serialize_value(comp)
  fn_ref = _create(fn_proto)
  arg_proto, _ = value_serialization.serialize_value(10, tf.int32)
  arg_ref = _create(arg_proto)

  call_response = env.stub.CreateCall(
      executor_pb2.CreateCallRequest(
          executor=env.executor_pb, function_ref=fn_ref, argument_ref=arg_ref))
  self.assertIsInstance(call_response, executor_pb2.CreateCallResponse)
  result = env.get_value(str(call_response.value_ref.id))
  self.assertEqual(result, 11)
  del env
def test_raising_failed_precondition_destroys_executor(self):
  """A FAILED_PRECONDITION raised by the executor makes the service drop it.

  The first failure surfaces lazily, when the created value is resolved. The
  test then expects the next request to fail eagerly with FAILED_PRECONDITION
  and a 'No executor found' message.
  """

  # A simple class to mock out the exceptions raised by GRPC. `code()` must be
  # a method of the exception so the service can inspect the status code.
  class GrpcFailedPrecondition(grpc.RpcError):

    def code(self):
      return grpc.StatusCode.FAILED_PRECONDITION

  def _return_error():
    return GrpcFailedPrecondition('Raising failed precondition')

  raising_ex = executor_test_utils.RaisingExecutor(_return_error)
  ex_factory = executor_test_utils.BasicTestExFactory(raising_ex)
  env = TestEnv(ex_factory)
  value_proto, _ = value_serialization.serialize_value(10, tf.int32)
  response = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(
          executor=env.executor_pb, value=value_proto))
  # When CreateValue is forced to run, the executor's error will be raised. We
  # raise lazily because the service operates future-to-future.
  with self.assertRaises(grpc.RpcError):
    env.get_value(response.value_ref.id)
  # The next CreateValue call, or indeed any call, should eagerly raise,
  # since the executor has been destroyed.
  with self.assertRaises(grpc.RpcError) as rpc_err:
    env.stub.CreateValue(
        executor_pb2.CreateValueRequest(
            executor=env.executor_pb, value=value_proto))
  self.assertEqual(rpc_err.exception.code(),
                   grpc.StatusCode.FAILED_PRECONDITION)
  # The raised RpcError is also a grpc.Call, so its details are read directly.
  self.assertIn('No executor found', rpc_err.exception.details())
def test_executor_service_create_one_arg_computation_value_and_call(self):
  """Calls a one-argument TF computation (x + 1) on 10 via the service."""
  env = TestEnv(eager_executor.EagerExecutor())

  @computations.tf_computation(tf.int32)
  def comp(x):
    return tf.add(x, 1)

  def create_ref(serialized):
    # Sends one serialized value to the service and returns its reference.
    created = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=serialized))
    self.assertIsInstance(created, executor_pb2.CreateValueResponse)
    return created.value_ref

  comp_proto, _ = executor_service_utils.serialize_value(comp)
  function_ref = create_ref(comp_proto)
  arg_proto, _ = executor_service_utils.serialize_value(10, tf.int32)
  argument_ref = create_ref(arg_proto)

  call = env.stub.CreateCall(
      executor_pb2.CreateCallRequest(
          function_ref=function_ref, argument_ref=argument_ref))
  self.assertIsInstance(call, executor_pb2.CreateCallResponse)
  self.assertEqual(env.get_value(str(call.value_ref.id)), 11)
  del env
async def create_value(self, value, type_spec=None):
  """Serializes `value` and registers it with the remote service.

  Args:
    value: The value to send to the remote executor.
    type_spec: Optional type of `value`, forwarded to the serializer.

  Returns:
    A `RemoteValue` wrapping the reference returned by the service, typed with
    the type spec returned by the serializer.
  """
  serialized, inferred_type = executor_service_utils.serialize_value(
      value, type_spec)
  request = executor_pb2.CreateValueRequest(value=serialized)
  response = self._stub.CreateValue(request)
  py_typecheck.check_type(response, executor_pb2.CreateValueResponse)
  return RemoteValue(response.value_ref, inferred_type, self)
def test_create_value_reraises_type_error(self, mock_executor_grpc_stub):
  """A TypeError raised by the underlying gRPC stub propagates unchanged."""
  mock_executor_grpc_stub.return_value.CreateValue = mock.Mock(
      side_effect=TypeError)
  stub = create_stub()
  with self.assertRaises(TypeError):
    stub.create_value(request=executor_pb2.CreateValueRequest())
def test_dispose_does_not_trigger_cleanup(self):
  """Disposing a value must not call `clean_up_executors` on the factory."""

  class MockFactory(executor_factory.ExecutorFactory, mock.MagicMock):

    def create_executor(self, *args, **kwargs):
      return mock.MagicMock()

    def clean_up_executors(self):
      return

  factory = MockFactory()
  # Replaced by a mock so that any call to it can be detected below.
  factory.clean_up_executors = mock.MagicMock()
  env = TestEnv(factory)
  value_proto, _ = executor_serialization.serialize_value(
      tf.constant(10.0).numpy(), tf.float32)

  # Create the value, then confirm it is present in the service's _values map.
  create_response = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=value_proto))
  self.assertIsInstance(create_response, executor_pb2.CreateValueResponse)
  env.get_value_future_directly(str(create_response.value_ref.id))

  # Dispose of the value.
  dispose_request = executor_pb2.DisposeRequest()
  dispose_request.value_ref.append(create_response.value_ref)
  env.stub.Dispose(dispose_request)

  # Close must not be propagated down the executor stack on Dispose: that
  # would close the bidi stream and hang the streaming case with intermediate
  # aggregation. Python GC takes care of pushing Dispose requests from the
  # aggregators to the workers.
  factory.clean_up_executors.assert_not_called()
def test_create_value_returns_value(self, mock_executor_grpc_stub):
  """`create_value` returns the stub's CreateValueResponse unchanged."""
  expected = executor_pb2.CreateValueResponse()
  grpc_stub_instance = mock_executor_grpc_stub.return_value
  grpc_stub_instance.CreateValue = mock.Mock(side_effect=[expected])
  stub = create_stub()
  actual = stub.create_value(request=executor_pb2.CreateValueRequest())
  self.assertEqual(actual, expected)
def test_executor_service_create_and_select_from_tuple(self):
  """Builds <a=10,b=20> from two ints and selects each element by index."""
  ex_factory = executor_test_utils.BasicTestExFactory(
      eager_tf_executor.EagerTFExecutor())
  env = TestEnv(ex_factory)

  def _create_int(python_value):
    # Registers an int32 with the service, checks it reads back, returns ref.
    value_proto, _ = value_serialization.serialize_value(
        python_value, tf.int32)
    created = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(
            executor=env.executor_pb, value=value_proto))
    self.assertIsInstance(created, executor_pb2.CreateValueResponse)
    ref = created.value_ref
    self.assertEqual(env.get_value(ref.id), python_value)
    return ref

  ten_ref = _create_int(10)
  twenty_ref = _create_int(20)

  named_refs = (('a', ten_ref), ('b', twenty_ref))
  struct_response = env.stub.CreateStruct(
      executor_pb2.CreateStructRequest(
          executor=env.executor_pb,
          element=[
              executor_pb2.CreateStructRequest.Element(
                  name=elem_name, value_ref=elem_ref)
              for elem_name, elem_ref in named_refs
          ]))
  self.assertIsInstance(struct_response, executor_pb2.CreateStructResponse)
  tuple_ref = struct_response.value_ref
  self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

  for position, expected in enumerate((10, 20)):
    selection = env.stub.CreateSelection(
        executor_pb2.CreateSelectionRequest(
            executor=env.executor_pb, source_ref=tuple_ref, index=position))
    self.assertIsInstance(selection, executor_pb2.CreateSelectionResponse)
    self.assertEqual(env.get_value(selection.value_ref.id), expected)
  del env
def test_executor_service_create_tensor_value(self):
  """A float32 value sent to the service is stored with the same value."""
  # NOTE(review): the result of serialize_value is used directly as the proto
  # here, unlike call sites that unpack (proto, type); presumably an older API
  # revision — confirm against the serializer in use.
  request = executor_pb2.CreateValueRequest(
      value=executor_service_utils.serialize_value(
          tf.constant(10.0).numpy(), tf.float32))
  response = self._stub.CreateValue(request)
  self.assertIsInstance(response, executor_pb2.CreateValueResponse)
  stored = self._extract_value_from_service(str(response.value_ref.id))
  self.assertEqual(stored.internal_representation.numpy(), 10.0)
def test_create_value_raises_retryable_error_on_grpc_error_unavailable(
    self, mock_executor_grpc_stub):
  """The error from `_raise_grpc_error_unavailable` surfaces as RetryableError."""
  grpc_stub_instance = mock_executor_grpc_stub.return_value
  grpc_stub_instance.CreateValue = mock.Mock(
      side_effect=_raise_grpc_error_unavailable)
  stub = create_stub()
  with self.assertRaises(executors_errors.RetryableError):
    stub.create_value(request=executor_pb2.CreateValueRequest())
def test_create_value_reraises_grpc_error(self, mock_executor_grpc_stub):
  """A non-retryable gRPC error is re-raised with its ABORTED status code."""
  mock_executor_grpc_stub.return_value.CreateValue = mock.Mock(
      side_effect=_raise_non_retryable_grpc_error)
  stub = create_stub()
  with self.assertRaises(grpc.RpcError) as raised:
    stub.create_value(request=executor_pb2.CreateValueRequest())
  self.assertEqual(raised.exception.code(), grpc.StatusCode.ABORTED)
def test_executor_service_raises_after_cleanup_without_configuration(self):
  """After ClearExecutor, a subsequent CreateValue raises grpc.RpcError."""
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))
  env.stub.ClearExecutor(executor_pb2.ClearExecutorRequest())
  serialized, _ = executor_serialization.serialize_value(
      tf.constant(10.0).numpy(), tf.float32)
  request = executor_pb2.CreateValueRequest(value=serialized)
  with self.assertRaises(grpc.RpcError):
    env.stub.CreateValue(request)
def _iterator():
  """Yields one ExecuteRequest that creates the value of a TF `x + 1`."""

  @computations.tf_computation(tf.int32)
  def comp(x):
    return tf.add(x, 1)

  serialized_comp, _ = executor_serialization.serialize_value(comp)
  yield executor_pb2.ExecuteRequest(
      create_value=executor_pb2.CreateValueRequest(value=serialized_comp))
def test_executor_service_create_and_select_from_tuple(self):
  """Builds <a=10,b=20> and selects its elements both by name and by index."""
  ex_factory = executor_stacks.ResourceManagingExecutorFactory(
      lambda _: eager_tf_executor.EagerTFExecutor())
  env = TestEnv(ex_factory)

  def _create_int(python_value):
    # Registers an int32 with the service, checks it reads back, returns ref.
    serialized, _ = executor_serialization.serialize_value(
        python_value, tf.int32)
    created = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=serialized))
    self.assertIsInstance(created, executor_pb2.CreateValueResponse)
    ref = created.value_ref
    self.assertEqual(env.get_value(ref.id), python_value)
    return ref

  ten_ref = _create_int(10)
  twenty_ref = _create_int(20)

  struct_response = env.stub.CreateStruct(
      executor_pb2.CreateStructRequest(element=[
          executor_pb2.CreateStructRequest.Element(
              name=elem_name, value_ref=elem_ref)
          for elem_name, elem_ref in (('a', ten_ref), ('b', twenty_ref))
      ]))
  self.assertIsInstance(struct_response, executor_pb2.CreateStructResponse)
  tuple_ref = struct_response.value_ref
  self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

  # Each selector is passed as keyword arguments to CreateSelectionRequest.
  selections = (
      ({'name': 'a'}, 10),
      ({'name': 'b'}, 20),
      ({'index': 0}, 10),
      ({'index': 1}, 20),
  )
  for selector, expected in selections:
    selection = env.stub.CreateSelection(
        executor_pb2.CreateSelectionRequest(source_ref=tuple_ref, **selector))
    self.assertIsInstance(selection, executor_pb2.CreateSelectionResponse)
    self.assertEqual(env.get_value(selection.value_ref.id), expected)
  del env
def test_executor_service_create_tensor_value(self):
  """A float32 value created through the service reads back as 10.0."""
  env = TestEnv(eager_executor.EagerExecutor())
  serialized, _ = executor_service_utils.serialize_value(
      tf.constant(10.0).numpy(), tf.float32)
  created = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=serialized))
  self.assertIsInstance(created, executor_pb2.CreateValueResponse)
  self.assertEqual(env.get_value(str(created.value_ref.id)), 10.0)
  del env
async def create_value(self, value, type_spec=None):
  """Serializes `value` (traced) and registers it with the remote service.

  Args:
    value: The value to send to the remote executor.
    type_spec: Optional type of `value`, forwarded to the serializer.

  Returns:
    A `RemoteValue` for the service-side reference, typed with the type spec
    returned by the serializer.
  """

  # Wrapped in a named, traced helper so serialization shows up in traces.
  @tracing.trace
  def serialize_value():
    return executor_serialization.serialize_value(value, type_spec)

  serialized, inferred_type = serialize_value()
  request = executor_pb2.CreateValueRequest(value=serialized)
  response = _request(self._stub.CreateValue, request)
  py_typecheck.check_type(response, executor_pb2.CreateValueResponse)
  return RemoteValue(response.value_ref, inferred_type, self)
def test_executor_service_slowly_create_tensor_value(self):
  """CreateValue returns before the executor finishes creating the value.

  The executor blocks inside `create_value` until the test sets `done`, so the
  test can observe the 'busy' state after the CreateValue RPC has returned.
  """

  class SlowExecutorValue(executor_value_base.ExecutorValue):
    """Trivial value holding a Python value and its type signature."""

    def __init__(self, v, t):
      self._value = v
      self._type = t

    @property
    def type_signature(self):
      return self._type

    async def compute(self):
      return self._value

  class SlowExecutor(executor_base.Executor):
    """Executor whose `create_value` blocks until `done` is set."""

    def __init__(self):
      self.status = 'idle'
      self.busy = threading.Event()
      self.done = threading.Event()

    async def create_value(self, value, type_spec=None):
      self.status = 'busy'
      self.busy.set()
      # Deliberately a blocking (non-awaited) wait: holds this coroutine until
      # the test thread sets `done`.
      self.done.wait()
      self.status = 'done'
      return SlowExecutorValue(value, type_spec)

    async def create_call(self, comp, arg=None):
      raise NotImplementedError

    async def create_struct(self, elements):
      raise NotImplementedError

    async def create_selection(self, source, index=None, name=None):
      raise NotImplementedError

    def close(self):
      pass

  slow_executor = SlowExecutor()
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: slow_executor))
  self.assertEqual(slow_executor.status, 'idle')

  serialized, _ = executor_serialization.serialize_value(10, tf.int32)
  created = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=serialized))

  # The RPC has returned; wait until the executor is inside create_value.
  slow_executor.busy.wait()
  self.assertEqual(slow_executor.status, 'busy')

  # Release the executor, then resolve the value.
  slow_executor.done.set()
  resolved = env.get_value(created.value_ref.id)
  self.assertEqual(slow_executor.status, 'done')
  self.assertEqual(resolved, 10)
async def create_value(self, value, type_spec=None):
  """Serializes `value` and registers it with the remote service.

  The request goes over the bidirectional stream when one is configured, and
  otherwise over the unary `CreateValue` stub call.

  Args:
    value: The value to send to the remote executor.
    type_spec: Optional type of `value`, forwarded to the serializer.

  Returns:
    A `RemoteValue` for the service-side reference, typed with the type spec
    returned by the serializer.
  """
  value_proto, type_spec = executor_service_utils.serialize_value(
      value, type_spec)
  create_value_request = executor_pb2.CreateValueRequest(value=value_proto)
  # Explicit `is None` sentinel check rather than truthiness, matching the
  # other RemoteExecutor request paths.
  if self._bidi_stream is None:
    response = self._stub.CreateValue(create_value_request)
  else:
    stream_response = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(create_value=create_value_request))
    response = stream_response.create_value
  py_typecheck.check_type(response, executor_pb2.CreateValueResponse)
  return RemoteValue(response.value_ref, type_spec, self)
def test_executor_service_create_tensor_value(self):
  """A float32 value created through the service reads back as 10.0."""
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))
  serialized, _ = executor_serialization.serialize_value(
      tf.constant(10.0).numpy(), tf.float32)
  created = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=serialized))
  self.assertIsInstance(created, executor_pb2.CreateValueResponse)
  self.assertEqual(env.get_value(str(created.value_ref.id)), 10.0)
  del env
def test_executor_service_create_computation_value(self):
  """A no-arg TF computation is stored by the service as a callable."""

  @computations.tf_computation
  def comp():
    return tf.constant(10)

  request = executor_pb2.CreateValueRequest(
      value=executor_service_utils.serialize_value(comp))
  response = self._stub.CreateValue(request)
  self.assertIsInstance(response, executor_pb2.CreateValueResponse)
  stored = self._extract_value_from_service(str(response.value_ref.id))
  self.assertTrue(callable(stored.internal_representation))
  self.assertEqual(stored.internal_representation().numpy(), 10.0)
async def create_value(self, value, type_spec=None):
  """Serializes `value` (traced) and registers it with the remote service.

  Uses the bidirectional stream when one is configured; otherwise issues a
  unary `CreateValue` call through `_request`.

  Args:
    value: The value to send to the remote executor.
    type_spec: Optional type of `value`, forwarded to the serializer.

  Returns:
    A `RemoteValue` for the service-side reference, typed with the type spec
    returned by the serializer.
  """

  # Wrapped in a named, traced helper so serialization shows up in traces.
  @tracing.trace
  def serialize_value():
    return executor_serialization.serialize_value(value, type_spec)

  serialized, inferred_type = serialize_value()
  request = executor_pb2.CreateValueRequest(value=serialized)
  if self._bidi_stream is not None:
    stream_response = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(create_value=request))
    response = stream_response.create_value
  else:
    response = _request(self._stub.CreateValue, request)
  py_typecheck.check_type(response, executor_pb2.CreateValueResponse)
  return RemoteValue(response.value_ref, inferred_type, self)
def test_executor_service_create_no_arg_computation_value_and_call(self):
  """Creating and calling a no-arg TF computation yields its constant 10."""
  env = TestEnv(eager_executor.EagerExecutor())

  @computations.tf_computation
  def comp():
    return tf.constant(10)

  serialized, _ = executor_service_utils.serialize_value(comp)
  created = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=serialized))
  self.assertIsInstance(created, executor_pb2.CreateValueResponse)

  called = env.stub.CreateCall(
      executor_pb2.CreateCallRequest(function_ref=created.value_ref))
  self.assertIsInstance(called, executor_pb2.CreateCallResponse)
  self.assertEqual(env.get_value(str(called.value_ref.id)), 10)
  del env
def iterator(self):
  """Yields create_value, then create_call, then compute ExecuteRequests.

  Before each request after the first, the previous step's response is read
  from `self.queue` so the next request can reference the value it created.
  """

  @computations.tf_computation()
  def comp():
    return 1

  comp_proto, _ = executor_serialization.serialize_value(comp)
  yield executor_pb2.ExecuteRequest(
      create_value=executor_pb2.CreateValueRequest(value=comp_proto))

  created = self.queue.get()
  yield executor_pb2.ExecuteRequest(
      create_call=executor_pb2.CreateCallRequest(
          function_ref=created.create_value.value_ref, argument_ref=None))

  called = self.queue.get()
  yield executor_pb2.ExecuteRequest(
      compute=executor_pb2.ComputeRequest(
          value_ref=called.create_call.value_ref))
def test_executor_service_create_no_arg_computation_value_and_call(self):
  """Creating and calling a no-arg TF computation yields its constant 10."""
  env = TestEnv(
      executor_stacks.ResourceManagingExecutorFactory(
          lambda _: eager_tf_executor.EagerTFExecutor()))

  @computations.tf_computation
  def comp():
    return tf.constant(10)

  serialized, _ = executor_serialization.serialize_value(comp)
  created = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=serialized))
  self.assertIsInstance(created, executor_pb2.CreateValueResponse)

  called = env.stub.CreateCall(
      executor_pb2.CreateCallRequest(function_ref=created.value_ref))
  self.assertIsInstance(called, executor_pb2.CreateCallResponse)
  self.assertEqual(env.get_value(str(called.value_ref.id)), 10)
  del env
def test_executor_service_value_unavailable_after_dispose(self):
  """After Dispose, the service's _values map no longer holds the value."""
  env = TestEnv(eager_tf_executor.EagerTFExecutor())
  serialized, _ = executor_service_utils.serialize_value(
      tf.constant(10.0).numpy(), tf.float32)

  # Create the value, then confirm it is present in the _values map.
  create_response = env.stub.CreateValue(
      executor_pb2.CreateValueRequest(value=serialized))
  self.assertIsInstance(create_response, executor_pb2.CreateValueResponse)
  value_id = str(create_response.value_ref.id)
  env.get_value_future_directly(value_id)

  # Dispose of the value.
  dispose_request = executor_pb2.DisposeRequest()
  dispose_request.value_ref.append(create_response.value_ref)
  dispose_response = env.stub.Dispose(dispose_request)
  self.assertIsInstance(dispose_response, executor_pb2.DisposeResponse)

  # Looking the value up directly, rather than through the gRPC thread, lets
  # the KeyError surface here where assertRaises can catch it.
  with self.assertRaises(KeyError):
    env.get_value_future_directly(value_id)