def test_from_proto(self, stub, client):
    """`Job._from_proto` wraps the given message and keeps or creates a client."""
    proto = job_pb2.Job(id="foo")

    built = Job._from_proto(proto, client=client)

    assert built._message == proto
    if client is None:
        # No client supplied: the job must end up holding a `Client` instance.
        assert isinstance(built._client, Client)
    else:
        assert built._client == client
def test_create_full(self, stub):
    """Create a `Job` from a parametrized object and inspect the request sent.

    Checks the state stored on the job as well as each field of the
    `CreateJobRequest` passed to the mocked ``CreateJob`` RPC.
    """
    one = types.Int(1)
    param = types.parameter("bar", types.Int)
    obj = one + param
    arguments = {"bar": 2}
    fmt = {"type": "pyarrow", "compression": "brotli"}
    dest = {"type": "download", "result_url": ""}

    create_rpc = stub.return_value.CreateJob
    response = job_pb2.Job()
    create_rpc.return_value = response

    job = Job(obj, format=fmt, destination=dest, **arguments)

    # --- state stored on the job itself ---
    assert job._message is response
    assert isinstance(job._client, Client)
    expected_type = types.Function[{param._name: type(param)}, type(obj)]
    assert isinstance(job._object, expected_type)
    assert job._arguments is None

    # --- how the RPC was invoked ---
    create_rpc.assert_called_once()
    call = create_rpc.call_args
    assert call.kwargs["timeout"] == Client.DEFAULT_TIMEOUT
    request = call.args[0]
    assert isinstance(request, job_pb2.CreateJobRequest)

    # --- contents of the request ---
    graft = json.loads(request.serialized_graft)
    assert graft_client.is_function_graft(graft)
    assert cereal.deserialize_typespec(request.typespec) is expected_type

    decoded_args = {
        name: json.loads(raw) for name, raw in request.arguments.items()
    }
    assert decoded_args == arguments

    assert request.geoctx_graft == ""
    assert request.no_ruster is False
    assert request.no_cache is False
    assert request.channel == _channel.__channel__
    assert request.trace is False
    assert request.type == types_pb2.ResultType.Value(type(obj).__name__)
    assert request.client_version == __version__
    assert has_proto_to_user_dict(request.format) == fmt
    assert has_proto_to_user_dict(request.destination) == dest
def create_side_effect(req, **kwargs):
    """Build the `Job` reply for a mocked create call, echoing request fields.

    ``id_`` and ``job_state`` are taken from the enclosing scope; extra
    keyword arguments are accepted and ignored.
    """
    echoed_fields = ("parameters", "serialized_graft", "typespec", "type", "channel")
    echoed = {name: getattr(req, name) for name in echoed_fields}
    return job_pb2.Job(id=id_, state=job_state, **echoed)
def test_load_result_error(self, stub):
    """`_load_result` raises `JobInvalid` for a failed job carrying ERROR_INVALID."""
    invalid_error = job_pb2.JobError(code=errors_pb2.ERROR_INVALID)
    failed_message = job_pb2.Job(
        id="foo",
        status=job_pb2.STATUS_FAILURE,
        error=invalid_error,
    )
    failed_job = Job(failed_message)

    with pytest.raises(JobInvalid):
        failed_job._load_result()
def test_init():
    """`JobComputeError` exposes the code, message and id of the job it wraps."""
    deadline_error = job_pb2.JobError(code=errors_pb2.ERROR_DEADLINE, message="bar")
    job = Job(job_pb2.Job(id="foo", error=deadline_error))

    exc = JobComputeError(job)

    assert exc.code == job.error.code
    assert exc.message == job.error.message
    assert exc.id == job.id
def test_watch(self, stub):
    """`Job.watch` yields one state per message returned by the WatchJob RPC."""
    job_id = "foo"
    job = Job(job_pb2.Job(id=job_id))

    # (stage, status) pairs the mocked RPC reports, in order.
    progression = (
        (job_pb2.STAGE_PENDING, job_pb2.STATUS_UNKNOWN),
        (job_pb2.STAGE_RUNNING, job_pb2.STATUS_UNKNOWN),
        (job_pb2.STAGE_DONE, job_pb2.STATUS_SUCCESS),
    )
    updates = [
        job_pb2.Job(id=job_id, stage=stage, status=status)
        for stage, status in progression
    ]
    stub.return_value.WatchJob.return_value = updates

    observed = [state._message for state in job.watch()]

    assert observed == updates
def test_wait_success(self, stub):
    """`Job.wait` records the SUCCEEDED stage reported by WatchJob."""
    download = user_destination_to_proto({"type": "download"})
    j = Job._from_proto(job_pb2.Job(id="foo", destination=download))

    succeeded = job_pb2.Job.State(stage=job_pb2.Job.Stage.SUCCEEDED)
    stub.return_value.WatchJob.return_value = [succeeded]

    j.wait()

    assert j._message.state.stage == succeeded.stage
def test_wait_for_result_success(self, stub):
    """`_wait_for_result` loads the result once after WatchJob reports SUCCEEDED."""
    j = Job._from_proto(job_pb2.Job(id="foo"))
    # Replace the loader so the test can count calls without fetching anything.
    j._load_result = mock.Mock()

    succeeded = job_pb2.Job.State(stage=job_pb2.Job.Stage.SUCCEEDED)
    stub.return_value.WatchJob.return_value = [succeeded]

    j._wait_for_result()

    j._load_result.assert_called_once_with()
    assert j._message.state.stage == succeeded.stage
def test_unmarshal_primitive(self, stub):
    """A job typed as `List` unmarshals a tuple of primitives into a list."""
    marshalled = (1, 2, True, None)
    done = job_pb2.Job.State(stage=job_pb2.Job.Stage.SUCCEEDED)
    message = job_pb2.Job(id="foo", state=done, type=types_pb2.List)
    job = Job._from_proto(message)

    unmarshalled = job._unmarshal(marshalled)

    assert unmarshalled == list(marshalled)
def create_side_effect(req, **kwargs):
    """Build the `Job` reply for a mocked create call.

    Request fields are echoed back; ``id_``, ``job_state``, ``format`` and
    ``destination`` come from the enclosing scope. Extra keyword arguments
    are accepted and ignored.
    """
    echoed_fields = ("parameters", "serialized_graft", "typespec", "type", "channel")
    echoed = {name: getattr(req, name) for name in echoed_fields}
    return job_pb2.Job(
        id=id_,
        state=job_state,
        format=user_format_to_proto(format),
        destination=user_destination_to_proto(destination),
        **echoed,
    )
def test_cancel(self, stub):
    """`Job.cancel` calls the CancelJob RPC with the job's id."""
    job = Job._from_proto(job_pb2.Job(id="foo"))
    cancel_rpc = stub.return_value.CancelJob
    cancel_rpc.return_value = job_pb2.CancelJobResponse()

    job.cancel()

    expected_request = job_pb2.CancelJobRequest(id=job.id)
    cancel_rpc.assert_called_with(
        expected_request,
        timeout=Client.DEFAULT_TIMEOUT,
        metadata=mock.ANY,
    )
def test_load_result_error(self, stub):
    """`_load_result` raises `JobInvalid` when the FAILED state carries ERROR_INVALID."""
    failed_state = job_pb2.Job.State(
        stage=job_pb2.Job.Stage.FAILED,
        error=job_pb2.Job.Error(code=errors_pb2.ERROR_INVALID),
    )
    failed_job = Job._from_proto(job_pb2.Job(id="foo", state=failed_state))

    with pytest.raises(JobInvalid):
        failed_job._load_result()
def test_wait_for_result_timeout(self, stub):
    """`_wait_for_result` raises `JobTimeoutError` while the job stays QUEUED."""
    j = Job._from_proto(job_pb2.Job(id="foo"))
    queued = job_pb2.Job.State(stage=job_pb2.Job.Stage.QUEUED)
    watch_rpc = stub.return_value.WatchJob
    watch_rpc.return_value = [queued]

    with pytest.raises(JobTimeoutError):
        j._wait_for_result(1e-4)

    assert j._message.state.stage == queued.stage
    watch_rpc.assert_called()
def test_create(self, stub):
    """Constructing a `Job` issues exactly one CreateJob call with the expected request."""
    obj = types.Int(1)
    parameters = {"foo": types.Str("bar")}
    typespec = cereal.serialize_typespec(type(obj))

    format_proto = user_format_to_proto({"type": "pyarrow", "compression": "brotli"})
    destination_proto = user_destination_to_proto({"type": "download"})
    result_type = types_pb2.ResultType.Value(cereal.typespec_to_unmarshal_str(typespec))

    expected_request = job_pb2.CreateJobRequest(
        parameters=json.dumps(parameters_to_grafts(**parameters)),
        serialized_graft=json.dumps(obj.graft),
        typespec=typespec,
        type=result_type,
        format=format_proto,
        destination=destination_proto,
        no_cache=False,
        channel=_channel.__channel__,
    )

    # The mocked reply mirrors every field of the request and adds an id.
    mirrored_fields = (
        "parameters",
        "serialized_graft",
        "typespec",
        "type",
        "format",
        "destination",
        "no_cache",
        "channel",
    )
    message = job_pb2.Job(
        id="foo",
        **{name: getattr(expected_request, name) for name in mirrored_fields},
    )
    create_rpc = stub.return_value.CreateJob
    create_rpc.return_value = message

    job = Job(
        obj,
        parameters,
        format={"type": "pyarrow", "compression": "brotli"},
        destination="download",
    )

    create_rpc.assert_called_once_with(
        expected_request,
        timeout=Client.DEFAULT_TIMEOUT,
        metadata=(("x-wf-channel", expected_request.channel),),
    )
    assert job._message is message
def test_create_cache(self, stub, cache):
    """The ``cache`` argument is sent inverted as ``no_cache`` and shown by `cache_enabled`."""
    id_ = "foo"
    obj = types.Int(1)

    def echo_no_cache(req, **kwargs):
        # Reply with a Job carrying the same `no_cache` flag as the request.
        return job_pb2.Job(id=id_, no_cache=req.no_cache)

    create_rpc = stub.return_value.CreateJob
    create_rpc.side_effect = echo_no_cache

    job = Job(obj, cache=cache)

    create_rpc.assert_called_once()
    sent_request = create_rpc.call_args[0][0]
    assert sent_request.no_cache == (not cache)
    assert job.cache_enabled == cache
def test_init():
    """`JobComputeError` exposes the code, message and id of the job it wraps."""
    deadline_error = job_pb2.Job.Error(code=errors_pb2.ERROR_DEADLINE, message="bar")
    message = job_pb2.Job(id="foo", state=job_pb2.Job.State(error=deadline_error))
    job = Job._from_proto(message)

    exc = JobComputeError(job)

    assert exc.code == job.error.code
    assert exc.message == job.error.message
    assert exc.id == job.id
def test_create_ruster(self, stub, ruster):
    """``no_ruster`` on the request is True only when ``_ruster`` is exactly False."""
    create_rpc = stub.return_value.CreateJob
    create_rpc.return_value = job_pb2.Job()

    Job(types.Int(1), _ruster=ruster)

    create_rpc.assert_called_once()
    sent_no_ruster = create_rpc.call_args.args[0].no_ruster
    # Identity checks on both sides: any value other than False must give False.
    expected_no_ruster = ruster is False
    assert sent_no_ruster is expected_no_ruster
def test_error_with_version_mismatch(self, stub):
    """A job created by client version 'foo' refuses arguments, object and resubmit."""
    mismatch_pattern = "was created by client version 'foo'"
    job = Job._from_proto(job_pb2.Job(client_version="foo"))

    with pytest.raises(NotImplementedError, match=mismatch_pattern):
        job.arguments  # the attribute access itself must raise

    with pytest.raises(NotImplementedError, match=mismatch_pattern):
        job.object  # the attribute access itself must raise

    with pytest.raises(NotImplementedError, match=mismatch_pattern):
        job.resubmit()
def test_get(self, stub, client):
    """`Job.get` fetches the message through GetJob and keeps or creates a client."""
    job_id = "foo"
    message = job_pb2.Job(id=job_id)
    get_rpc = stub.return_value.GetJob
    get_rpc.return_value = message

    job = Job.get(job_id, client=client)

    assert job._message == message
    get_rpc.assert_called_with(
        job_pb2.GetJobRequest(id=job_id),
        timeout=Client.DEFAULT_TIMEOUT,
    )
    if client is None:
        # No client supplied: the job must end up holding a `Client` instance.
        assert isinstance(job._client, Client)
    else:
        assert job._client == client
def test_wait_for_result_failure(self, stub):
    """`_wait_for_result` raises `JobComputeError` on a FAILED state with ERROR_UNKNOWN."""
    j = Job._from_proto(job_pb2.Job(id="foo"))
    failed = job_pb2.Job.State(
        stage=job_pb2.Job.Stage.FAILED,
        error=job_pb2.Job.Error(code=errors_pb2.ERROR_UNKNOWN),
    )
    stub.return_value.WatchJob.return_value = [failed]

    with pytest.raises(JobComputeError):
        j._wait_for_result()

    assert j._message.state.stage == failed.stage
def test_wait_for_result_terminated(self, stub):
    """`_wait_for_result` raises `JobTerminated` on a FAILED state with ERROR_TERMINATED."""
    j = Job._from_proto(job_pb2.Job(id="foo"))
    terminated = job_pb2.Job.State(
        stage=job_pb2.Job.Stage.FAILED,
        error=job_pb2.Job.Error(code=errors_pb2.ERROR_TERMINATED),
    )
    stub.return_value.WatchJob.return_value = [terminated]

    with pytest.raises(JobTerminated):
        j._wait_for_result()

    assert j._message.state.stage == terminated.stage
def test_wait_terminated(self, stub):
    """`Job.wait` raises `JobTerminated` on a FAILED state with ERROR_TERMINATED."""
    download = user_destination_to_proto({"type": "download"})
    j = Job._from_proto(job_pb2.Job(id="foo", destination=download))
    terminated = job_pb2.Job.State(
        stage=job_pb2.Job.Stage.FAILED,
        error=job_pb2.Job.Error(code=errors_pb2.ERROR_TERMINATED),
    )
    stub.return_value.WatchJob.return_value = [terminated]

    with pytest.raises(JobTerminated):
        j.wait()

    assert j._message.state.stage == terminated.stage
def test_wait_for_result_timeout(self, stub):
    """`_wait_for_result` raises when GetJob keeps returning a pending job."""
    unknown_status = job_pb2.STATUS_UNKNOWN
    pending_message = job_pb2.Job(
        id="foo",
        status=unknown_status,
        stage=job_pb2.STAGE_PENDING,
    )
    j = Job(pending_message)
    get_rpc = stub.return_value.GetJob
    get_rpc.return_value = pending_message

    with pytest.raises(Exception):  # TODO(justin) fix exception type
        j._wait_for_result(1e-4)

    assert j._message.status == unknown_status
    get_rpc.assert_called()
def create_side_effect(req, **kwargs):
    """Build the `Job` reply for a mocked create call.

    Request fields are echoed back. ``id_``, ``expires_timestamp``,
    ``job_state``, ``format`` and ``destination`` come from the enclosing
    scope. Extra keyword arguments are accepted and ignored.
    """
    echoed_fields = (
        "serialized_graft",
        "typespec",
        "arguments",
        "geoctx_graft",
        "no_ruster",
        "channel",
        "no_cache",
        "trace",
        "type",
    )
    echoed = {name: getattr(req, name) for name in echoed_fields}
    return job_pb2.Job(
        id=id_,
        client_version=__version__,
        expires_timestamp=expires_timestamp,
        state=job_state,
        format=user_format_to_proto(format),
        destination=user_destination_to_proto(destination),
        **echoed,
    )
def test_watch(self, stub):
    """`Job.watch` exposes each state returned by WatchJob, in order."""
    job = Job._from_proto(job_pb2.Job(id="foo"))
    stages = (
        job_pb2.Job.Stage.QUEUED,
        job_pb2.Job.Stage.RUNNING,
        job_pb2.Job.Stage.SUCCEEDED,
    )
    expected_states = [job_pb2.Job.State(stage=stage) for stage in stages]
    stub.return_value.WatchJob.return_value = expected_states

    def snapshot(watched):
        # Copy the state out of the yielded job at the moment it is yielded.
        copied = job_pb2.Job.State()
        copied.CopyFrom(watched._message.state)
        return copied

    observed_states = [snapshot(job_) for job_ in job.watch()]

    assert observed_states == expected_states
def test_wait_timeout(self, stub):
    """`Job.wait` raises `JobTimeoutError` when WatchJob fails with DEADLINE_EXCEEDED."""
    download = user_destination_to_proto({"type": "download"})
    j = Job._from_proto(job_pb2.Job(id="foo", destination=download))
    queued = job_pb2.Job.State(stage=job_pb2.Job.Stage.QUEUED)

    def queued_then_deadline(*args, **kwargs):
        # Yield one QUEUED update, then fail with a deadline error.
        yield queued
        raise MockRpcError(grpc.StatusCode.DEADLINE_EXCEEDED)

    watch_rpc = stub.return_value.WatchJob
    watch_rpc.side_effect = queued_then_deadline

    with pytest.raises(JobTimeoutError):
        j.wait(timeout=1)

    watch_rpc.assert_called()
    assert j._message.state.stage == queued.stage
def test_result_to_file(self, stub, file_path, tmpdir):
    """`Job.result_to_file` writes the downloaded JSON result to a path or open file.

    When ``file_path`` is truthy the target is passed as a string path;
    otherwise an already-open binary file object is passed, and the test
    checks that `result_to_file` leaves it open.
    """
    format_proto = user_format_to_proto("json")
    destination_proto = user_destination_to_proto("download")
    destination_proto.download.result_url = (
        "https://storage.googleapis.com/dl-compute-dev-results"
    )
    job = Job._from_proto(
        job_pb2.Job(
            id="foo",
            state=job_pb2.Job.State(stage=job_pb2.Job.Stage.SUCCEEDED),
            format=format_proto,
            destination=destination_proto,
        )
    )

    result = [1, 2, 3, 4]
    responses.add(
        responses.GET,
        job.url,
        body=json.dumps(result),
        headers={"x-goog-stored-content-encoding": "application/json"},
        status=200,
        stream=True,
    )

    path = tmpdir.join("test.json")
    file_arg = str(path) if file_path else path.open("wb")
    try:
        job.result_to_file(file_arg)

        if not file_path:
            # A caller-supplied file object must not be closed for the caller.
            assert not file_arg.closed
            file_arg.flush()

        with open(str(path), "r") as f:
            assert result == json.load(f)
    finally:
        # Close the handle this test opened even if a step above failed,
        # so a failing run does not leak an open file.
        if not file_path:
            file_arg.close()
def test_download_result(self, stub):
    """`_download_result` decompresses and deserializes the stored result."""
    done = job_pb2.Job.State(stage=job_pb2.Job.Stage.SUCCEEDED)
    job = Job._from_proto(job_pb2.Job(id="foo", state=done))

    result = {}
    buffer = pa.serialize(result, context=serialization_context).to_buffer()
    codec = "lz4"
    # Headers name the codec and the uncompressed size of the body.
    compression_headers = {
        "x-goog-meta-codec": codec,
        "x-goog-meta-decompressed_size": str(len(buffer)),
    }
    responses.add(
        responses.GET,
        Job.BUCKET_PREFIX.format(job.id),
        body=pa.compress(buffer, codec=codec, asbytes=True),
        headers=compression_headers,
        status=200,
    )

    assert job._download_result() == result
def test_unmarshal_image(self, stub):
    """A job typed as `Image` unmarshals a dict into an `ImageResult`."""
    properties = {
        "foo": "bar",
        "geometry": {"type": "Point", "coordinates": [0, 0]},
    }
    geocontext = {
        "geometry": None,
        "crs": "EPSG:4326",
        "bounds": (-98, 40, -90, 44),
    }
    marshalled = {
        "ndarray": {"red": []},
        "properties": properties,
        "bandinfo": {"red": {}},
        "geocontext": geocontext,
    }
    message = job_pb2.Job(
        id="foo",
        status=job_pb2.STATUS_SUCCESS,
        type=types_pb2.Image,
    )
    job = Job(message)

    result = job._unmarshal(marshalled)

    # NOTE(gabe): we check the class name, versus `isinstance(result, results.ImageResult)`,
    # because importing results in this test would register its unmarshallers,
    # and part of what we're testing for is that the unmarshallers are getting registered correctly.
    assert result.__class__.__name__ == "ImageResult"
    for attr in ("ndarray", "properties", "bandinfo", "geocontext"):
        assert getattr(result, attr) == marshalled[attr]
def build(cls, proxy_object, parameters, channel=None, client=None):
    """
    Build a new `Job` for computing a proxy object under certain parameters.

    Does not actually trigger computation; call `Job.execute` on the result to do so.

    Parameters
    ----------
    proxy_object: Proxytype
        Proxy object to compute
    parameters: dict[str, Proxytype]
        Python dictionary of parameter names and values
    channel: str or None, optional
        Channel name to submit the `Job` to.
        If None, uses the default channel for this client
        (``descarteslabs.workflows.__channel__``).

        Channels are different versions of the backend,
        to allow for feature changes without breaking existing code.
        Not all clients are compatible with all channels.
        This client is only guaranteed to work with its default channel,
        whose name can be found under ``descarteslabs.workflows.__channel__``.
    client : `.workflows.client.Client`, optional
        Allows you to use a specific client instance with non-default
        auth and parameters

    Returns
    -------
    Job
        The job waiting to be executed.

    Example
    -------
    >>> from descarteslabs.workflows import Job, Int, parameter
    >>> my_int = Int(1) + parameter("other_int", Int)
    >>> job = Job.build(my_int, {"other_int": 10})
    >>> # the job does not execute until `.execute` is called
    >>> job.execute() # doctest: +SKIP
    """
    # NOTE(gabe): we look up the variable from the `_channel` package here,
    # rather than importing it directly at the top,
    # so it can easily be changed during an interactive session.
    channel = _channel.__channel__ if channel is None else channel
    client = Client() if client is None else client

    serialized_type = serialize_typespec(type(proxy_object))
    # Resolving the unmarshal name up front also preemptively checks whether
    # the result type is something we'll know how to unmarshal.
    unmarshal_name = _typespec_to_unmarshal_str(serialized_type)

    param_grafts = parameters_to_grafts(**parameters)

    message = job_pb2.Job(
        parameters=json.dumps(param_grafts),
        serialized_graft=json.dumps(proxy_object.graft),
        typespec=serialized_type,
        type=types_pb2.ResultType.Value(unmarshal_name),
        channel=channel,
    )

    job = cls(message, client)
    # Keep the original proxy object on the new job.
    job._object = proxy_object
    return job