def test_task_runner_sets_task_name(monkeypatch, cloud_settings):
    """``set_task_run_name`` renders the task's ``task_run_name`` and reports it.

    Three forms of ``task_run_name`` are covered:
      * a plain string, used verbatim;
      * a format string, filled from ``task_inputs`` (each input exposes ``.value``);
      * a callable, whose return value is used.

    For each case the rendered name must be sent to the client along with the
    task run id, and stored in ``prefect.context`` as ``task_run_name``.
    """
    client = MagicMock()
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=client)
    )

    class Temp:
        # Stand-in for a task input; only the ``.value`` attribute is read when
        # formatting "{map_index}" (expected rendering is "100").
        value = 100

    cases = [
        # (task_run_name, task_inputs, expected rendered name)
        ("asdf", {}, "asdf"),
        ("{map_index}", {"map_index": Temp()}, "100"),
        (lambda **kwargs: "name", {}, "name"),
    ]

    for task_run_name, task_inputs, expected in cases:
        # A fresh mock per case: with one shared mock, `.called` would already be
        # True from the first case and the later `.called` checks would prove nothing.
        client.set_task_run_name = MagicMock()

        task = Task(name="test", task_run_name=task_run_name)
        runner = CloudTaskRunner(task=task)
        runner.task_run_id = "id"

        with prefect.context():
            assert prefect.context.get("task_run_name") is None
            runner.set_task_run_name(task_inputs=task_inputs)

            assert client.set_task_run_name.called
            assert client.set_task_run_name.call_args[1]["name"] == expected
            assert client.set_task_run_name.call_args[1]["task_run_id"] == "id"
            assert prefect.context.get("task_run_name") == expected
def test_task_runner_does_not_have_heartbeat_if_disabled(self, monkeypatch):
    """``_heartbeat`` returns ``False`` when the flow settings turn heartbeats off.

    The mocked client's graphql response reports ``heartbeat_enabled=False`` for
    the flow of the current flow run.
    """
    client = MagicMock()
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=client)
    )
    flow_run = client.graphql.return_value.data.flow_run_by_pk
    flow_run.flow.settings = {"heartbeat_enabled": False}

    runner = CloudTaskRunner(task=Task())
    runner.task_run_id = "foo"

    assert runner._heartbeat() is False
def test_heartbeat_traps_graphql_errors_caused_by_client(self, caplog, monkeypatch):
    """An exception raised by ``client.graphql`` during ``_heartbeat`` is trapped.

    Expected behavior:
      * ``_heartbeat`` returns ``False`` instead of propagating the error;
      * an ERROR-level record naming the task ("Heartbeat failed for Task 'bad'")
        is logged.
    """
    # NOTE: this test used to share its name with the warning-based heartbeat
    # test defined after it. The later definition rebinds the name, so this
    # test was never collected. It was renamed so both run.
    client = MagicMock(graphql=MagicMock(side_effect=SyntaxError))
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=client)
    )
    runner = CloudTaskRunner(task=Task(name="bad"))
    runner.task_run_id = None

    res = runner._heartbeat()

    assert res is False
    # Search for the ERROR record rather than assuming it is the first record;
    # unrelated records emitted earlier would otherwise break the test.
    error_records = [r for r in caplog.records if r.levelname == "ERROR"]
    assert error_records
    assert "Heartbeat failed for Task 'bad'" in error_records[0].getMessage()
def test_heartbeat_traps_errors_caused_by_client(self, monkeypatch):
    """A failing ``update_task_run_heartbeat`` call is turned into a warning.

    The client's ``update_task_run_heartbeat`` raises ``SyntaxError``. In that case
    ``_heartbeat`` must:
      * call it;
      * emit a ``UserWarning`` mentioning the task name;
      * return ``None``.
    """
    failing_update = MagicMock(side_effect=SyntaxError)
    client = MagicMock(update_task_run_heartbeat=failing_update)
    client_factory = MagicMock(return_value=client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

    runner = CloudTaskRunner(task=Task(name="bad"))
    runner.task_run_id = None

    with pytest.warns(UserWarning) as recorded:
        outcome = runner._heartbeat()

    assert outcome is None
    assert failing_update.called
    last_warning = recorded.pop()
    assert "Heartbeat failed for Task 'bad'" in repr(last_warning.message)
def test_task_runner_heartbeat_sets_command(self, monkeypatch):
    """A successful ``_heartbeat`` returns ``True`` and records the CLI command.

    After the call, ``runner.heartbeat_cmd`` must be the ``prefect heartbeat
    task-run`` invocation for the runner's task run id, and ``task_run_id``
    must be left unchanged.
    """
    client = MagicMock()
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=client)
    )
    # Use the same settings key that the disabled-heartbeat test flips to False,
    # so the "heartbeats enabled" precondition is explicit. The previous unrelated
    # ``disable_heartbeat`` key left the outcome to whatever default applies when
    # ``heartbeat_enabled`` is absent.
    client.graphql.return_value.data.flow_run_by_pk.flow.settings = dict(
        heartbeat_enabled=True
    )
    runner = CloudTaskRunner(task=Task())
    runner.task_run_id = "foo"

    res = runner._heartbeat()

    assert res is True
    assert runner.task_run_id == "foo"
    assert runner.heartbeat_cmd == [
        "prefect",
        "heartbeat",
        "task-run",
        "-i",
        "foo",
    ]