Example #1
0
def test_client_deploy(patch_post, compressed):
    """deploy() returns the flow id reported by the mocked GraphQL API.

    The mutation key in the mocked payload depends on whether the flow is
    uploaded as a compressed string or not.
    """
    mutation = "createFlowFromCompressedString" if compressed else "createFlow"
    patch_post({
        "data": {
            "project": [{"id": "proj-id"}],
            mutation: {"id": "long-id"},
        }
    })

    config = {
        "cloud.graphql": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(config):
        client = Client()

    storage = prefect.environments.storage.Memory()
    flow = prefect.Flow(name="test", storage=storage)
    flow_id = client.deploy(
        flow, project_name="my-default-project", compressed=compressed
    )
    assert flow_id == "long-id"
Example #2
0
    def test_graphql_uses_access_token_after_login(self, patch_post):
        """graphql() sends the API token until login_to_tenant swaps in an access token."""
        tenant_id = str(uuid.uuid4())
        switch_tenant = {
            "accessToken": "ACCESS_TOKEN",
            "expiresIn": 600,
            "refreshToken": "REFRESH_TOKEN",
        }
        post = patch_post(
            {"data": {"tenant": [{"id": tenant_id}], "switchTenant": switch_tenant}}
        )

        def sent_headers():
            # keyword arguments of the most recent mocked POST call
            return post.call_args[1]["headers"]

        client = Client(api_token="api")
        client.graphql({})
        assert client.get_auth_token() == "api"
        assert sent_headers() == {"Authorization": "Bearer api"}

        client.login_to_tenant(tenant_id=tenant_id)
        client.graphql({})
        assert client.get_auth_token() == "ACCESS_TOKEN"
        assert sent_headers() == {"Authorization": "Bearer ACCESS_TOKEN"}
Example #3
0
def test_set_task_run_state_responds_to_status(patch_post):
    """A QUEUED status from the API comes back as a queued state with no inner state."""
    queued = {"status": "QUEUED"}
    patch_post({"data": {"setTaskRunStates": {"states": [queued]}}})
    state = Pending()

    with set_temporary_config(
        {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
    ):
        client = Client()

    result = client.set_task_run_state(task_run_id="76-salt", version=0, state=state)

    assert result.is_queued()
    # the wrapped state is left for the caller to fill in
    assert result.state is None
Example #4
0
def test_get_task_run_info_with_error(monkeypatch):
    """GraphQL errors in the response are raised as ClientError with the API message."""
    payload = {
        "data": {"getOrCreateTaskRun": None},
        "errors": [{"message": "something went wrong"}],
    }
    fake_response = MagicMock(json=MagicMock(return_value=payload))
    monkeypatch.setattr("requests.post", MagicMock(return_value=fake_response))

    with set_temporary_config(
        {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
    ):
        client = Client()

    with pytest.raises(ClientError) as excinfo:
        client.get_task_run_info(
            flow_run_id="74-salt", task_id="72-salt", map_index=None
        )

    assert "something went wrong" in str(excinfo.value)
Example #5
0
def test_set_task_run_state_with_error(monkeypatch):
    """A GraphQL error from setTaskRunState surfaces as ClientError carrying the message."""
    payload = {
        "data": {"setTaskRunState": None},
        "errors": [{"message": "something went wrong"}],
    }
    fake_response = MagicMock(json=MagicMock(return_value=payload))
    monkeypatch.setattr("requests.post", MagicMock(return_value=fake_response))

    with set_temporary_config(
        {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
    ):
        client = Client()

    with pytest.raises(ClientError) as excinfo:
        client.set_task_run_state(task_run_id="76-salt", version=0, state=Pending())
    assert "something went wrong" in str(excinfo.value)
Example #6
0
def test_get_default_tenant_slug_as_user(patch_post):
    """With a user token, the slug comes from the user's default membership tenant."""
    membership = {"tenant": {"slug": "tslug"}}
    patch_post({"data": {"user": [{"default_membership": membership}]}})

    config = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
        "backend": "cloud",
    }
    with set_temporary_config(config):
        client = Client()
        # queried while the temporary config is still active
        assert client.get_default_tenant_slug() == "tslug"
Example #7
0
def test_get_task_run_info(patch_post):
    """get_task_run_info deserializes the task run id, version and Pending state."""
    task_run_id = "772bd9ee-40d7-479c-9839-4ab3a793cabd"
    serialized_state = {
        "type": "Pending",
        "_result": {
            "type": "SafeResult",
            "value": "42",
            "result_handler": {"type": "JSONResultHandler"},
        },
        "message": None,
        "__version__": "0.3.3+310.gd19b9b7.dirty",
        "cached_inputs": None,
    }
    task_run = {
        "id": task_run_id,
        "version": 0,
        "serialized_state": serialized_state,
        "task": {"slug": "slug"},
    }
    patch_post({"data": {"getOrCreateTaskRun": {"task_run": task_run}}})

    with set_temporary_config(
        {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
    ):
        client = Client()

    info = client.get_task_run_info(
        flow_run_id="74-salt", task_id="72-salt", map_index=None
    )

    assert isinstance(info, TaskRunInfoResult)
    state = info.state
    assert isinstance(state, Pending)
    assert state.result == "42"
    assert state.message is None
    assert info.id == task_run_id
    assert info.version == 0
Example #8
0
def test_set_task_run_state_serializes(patch_post):
    """set_task_run_state raises ValidationError when the state's result cannot be serialized."""
    patch_post(
        {"data": {"set_task_run_states": {"states": [{"status": "SUCCESS"}]}}}
    )

    with set_temporary_config(
        {
            "cloud.api": "http://my-cloud.foo",
            "cloud.auth_token": "secret_token",
            "backend": "cloud",
        }
    ):
        client = Client()

    # presumably unserializable: the value is a lambda and no result handler is set
    unserializable = SafeResult(lambda: None, result_handler=None)
    with pytest.raises(marshmallow.exceptions.ValidationError):
        client.set_task_run_state(
            task_run_id="76-salt", version=0, state=Pending(result=unserializable)
        )
Example #9
0
def test_set_task_run_state(patch_post):
    """On SUCCESS the client hands back the very state object it was given."""
    patch_post(
        {"data": {"set_task_run_states": {"states": [{"status": "SUCCESS"}]}}}
    )
    state = Pending()

    with set_temporary_config(
        {
            "cloud.api": "http://my-cloud.foo",
            "cloud.auth_token": "secret_token",
            "backend": "cloud",
        }
    ):
        client = Client()

    returned = client.set_task_run_state(task_run_id="76-salt", version=0, state=state)
    assert returned is state
Example #10
0
def test_get_flow_run_state(patch_posts, cloud_api, runner_token):
    """get_flow_run_state deserializes the stored Pending state and its result."""
    serialized_state = {
        "type": "Pending",
        "_result": {
            "type": "SafeResult",
            "value": "42",
            "result_handler": {"type": "JSONResultHandler"},
        },
        "message": None,
        "__version__": "0.3.3+310.gd19b9b7.dirty",
        "cached_inputs": None,
    }
    patch_posts(
        [{"data": {"flow_run_by_pk": {"serialized_state": serialized_state}}}]
    )

    state = Client().get_flow_run_state(flow_run_id="72-salt")

    assert isinstance(state, Pending)
    assert state.result == "42"
    assert state.message is None
Example #11
0
    def test_login_to_client_doesnt_reload_active_tenant_when_token_isnt_loaded(
        self, patch_post
    ):
        """Logging in sets the active tenant on that client only.

        A fresh Client() constructed without an api token does not pick up the
        tenant the first client switched to.
        """
        tenant_id = str(uuid.uuid4())
        switch_tenant = {
            "accessToken": "ACCESS_TOKEN",
            "expiresIn": 600,
            "refreshToken": "REFRESH_TOKEN",
        }
        patch_post(
            {"data": {"tenant": [{"id": tenant_id}], "switchTenant": switch_tenant}}
        )

        logged_in = Client(api_token="abc")
        assert logged_in._active_tenant_id is None
        logged_in.login_to_tenant(tenant_id=tenant_id)
        assert logged_in._active_tenant_id == tenant_id

        # no api token is loaded here, so the active tenant stays unset
        assert Client()._active_tenant_id is None
    def test_client_token_priotizes_config_over_file(self, cloud_api):
        """A token set via cloud.auth_token wins over one saved in the settings file."""
        with tempfile.TemporaryDirectory() as tmp:
            with set_temporary_config({
                    "home_dir": tmp,
                    "cloud.graphql": "xyz",
                    "cloud.auth_token": "CONFIG_TOKEN",
            }):
                # write a competing token into the per-server settings file
                path = Path(tmp) / "client" / "xyz" / "settings.toml"
                path.parent.mkdir(parents=True)
                with path.open("w") as f:
                    toml.dump(dict(api_token="FILE_TOKEN"), f)

                client = Client()
        assert client._api_token == "CONFIG_TOKEN"
Example #13
0
def test_write_log_with_error(monkeypatch):
    """Errors returned by the writeRunLog mutation are raised as ClientError."""
    payload = {
        "data": {"writeRunLog": None},
        "errors": [{"message": "something went wrong"}],
    }
    fake_post = MagicMock(return_value=MagicMock(json=MagicMock(return_value=payload)))
    # requests.Session() is replaced so that session.post returns the payload
    session_cls = MagicMock()
    session_cls.return_value.post = fake_post
    monkeypatch.setattr("requests.Session", session_cls)

    with set_temporary_config(
        {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
    ):
        client = Client()

    with pytest.raises(ClientError) as excinfo:
        client.write_run_log(flow_run_id="1")
    assert "something went wrong" in str(excinfo.value)
Example #14
0
def create_prefect_flow_run(flow_name: str, project_name: str, task_refs: List,
                            params: Mapping, *, poll_interval: float = 10,
                            timeout: float = float("inf")) -> tuple:
    """Start a Prefect flow run and collect the results of selected tasks.

    A new run of ``flow_name`` in ``project_name`` is started with ``params``
    as its parameters. The run is then polled every ``poll_interval`` seconds
    until its state is finished. For every task run whose slug contains one of
    the strings in ``task_refs``, the result location is taken from the task
    run's state and the value is loaded with ``LocalResult``.

    Args:
        flow_name: name of the flow to run.
        project_name: project the flow belongs to.
        task_refs: substrings used to select task runs by slug; the first
            matching ref (in ``task_refs`` order) becomes that task's key.
        params: parameters passed to the flow run.
        poll_interval: seconds to sleep before each state check (default 10).
        timeout: upper bound in seconds on total polling time; the default
            (infinity) waits until the run finishes.

    Returns:
        ``(task_results, flow_state, task_res_locs)``: a mapping of ref to
        result value, the final flow run state, and a mapping of ref to
        result location.

    Raises:
        TimeoutError: if the run has not finished within ``timeout`` seconds.
    """
    flow_run_id = StartFlowRun(flow_name=flow_name,
                               project_name=project_name,
                               parameters=params).run()
    client = Client()
    deadline = time.monotonic() + timeout

    while True:
        # sleep before every check, including the first one
        time.sleep(poll_interval)
        flow_run_info = client.get_flow_run_info(flow_run_id)
        flow_state = flow_run_info.state
        if flow_state.is_finished():
            break
        if time.monotonic() >= deadline:
            raise TimeoutError(
                "Flow run {} did not finish within {} seconds".format(
                    flow_run_id, timeout))

    task_res_locs = {}
    for task_run in flow_run_info.task_runs:
        # first ref (in task_refs order) that is a substring of the task slug
        ref = next((ref_str for ref_str in task_refs
                    if ref_str in task_run.task_slug), None)
        if ref:
            task_state = client.get_task_run_state(task_run.id)
            task_res_locs[ref] = task_state._result.location

    task_results = {ref: LocalResult().read(loc).value
                    for ref, loc in task_res_locs.items()}
    return task_results, flow_state, task_res_locs
    def test_save_local_settings(self, cloud_api):
        """save_api_token writes to the per-server settings file, replacing older tokens."""
        with tempfile.TemporaryDirectory() as tmp:
            with set_temporary_config({"home_dir": tmp, "cloud.graphql": "xyz"}):
                settings_file = Path(tmp) / "client" / "xyz" / "settings.toml"

                # a second save must overwrite the first one
                for token in ("a", "b"):
                    client = Client(api_token=token)
                    client.save_api_token()
                    with settings_file.open("r") as f:
                        assert toml.load(f)["api_token"] == token
Example #16
0
def test_client_register(patch_post, compressed, monkeypatch, tmpdir):
    """register() returns the new flow id for both compressed and plain uploads."""
    mutation = (
        "create_flow_from_compressed_string" if compressed else "create_flow"
    )
    patch_post(
        {"data": {"project": [{"id": "proj-id"}], mutation: {"id": "long-id"}}}
    )

    # stub the tenant-slug lookup with a fixed value
    monkeypatch.setattr(
        "prefect.client.Client.get_default_tenant_slug", MagicMock(return_value="tslug")
    )

    config = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
        "backend": "cloud",
    }
    with set_temporary_config(config):
        client = Client()

    storage = prefect.environments.storage.Local(tmpdir)
    flow = prefect.Flow(name="test", storage=storage)
    flow.result = flow.storage.result

    flow_id = client.register(
        flow,
        project_name="my-default-project",
        compressed=compressed,
        version_group_id=str(uuid.uuid4()),
    )
    assert flow_id == "long-id"
Example #17
0
    def test_load_local_api_token_is_called_when_the_client_is_initialized_without_token(
        self, cloud_api
    ):
        """A Client built without a token picks up the token a previous Client saved."""
        with tempfile.TemporaryDirectory() as tmp:
            with set_temporary_config({"home_dir": tmp}):
                Client(api_token="a").save_api_token()
                assert Client()._api_token == "a"
Example #18
0
def test_logout_removes_token(monkeypatch):
    """logout() deletes the token file that login() wrote."""
    # delete=False: the file must outlive the with-block so logout() can remove it
    with tempfile.NamedTemporaryFile(delete=False) as token_file:
        monkeypatch.setattr("prefect.client.Client.local_token_path", token_file.name)
        client = Client()
        client.login(api_token="a")
        assert token_file.read() == b"a"

    client.logout()
    assert not os.path.exists(token_file.name)
Example #19
0
 def test_get_auth_token_doesnt_refresh_if_refresh_token_and_future_expiration(
         self, monkeypatch):
     """A cached access token that expires in the future is returned without refreshing."""
     refresh_mock = MagicMock()
     monkeypatch.setattr("prefect.Client._refresh_access_token", refresh_mock)

     client = Client(api_token="api")
     client._access_token = "access"
     client._refresh_token = "refresh"
     # expiry ten minutes ahead, so the cached token is still valid
     client._access_token_expires_at = pendulum.now().add(minutes=10)

     assert client.get_auth_token() == "access"
     refresh_mock.assert_not_called()
    def test_client_infers_correct_tenant_if_a_token_is_not_user_scoped(
            self, patch_posts, cloud_api):
        """If login to the saved tenant fails, the tenant reported by the API is used.

        The tenant id saved on disk is left untouched.
        """
        unauthenticated = {
            "errors": [{
                "message": "",
                "locations": [],
                "path": ["tenant"],
                "extensions": {"code": "UNAUTHENTICATED"},
            }]
        }
        tenant_lookup = {"data": {"tenant": [{"id": "tenant-id"}]}}
        # responses are served in order: first the auth failure, then the tenant id
        patch_posts([unauthenticated, tenant_lookup])

        saved_tenant = str(uuid.uuid4())
        # throwaway client, used only for its settings helpers
        Client()._save_local_settings(
            dict(api_token="API_TOKEN", active_tenant_id=saved_tenant))

        # fails to log into the saved tenant, then loads the right one from the API
        client = Client(api_token="API_TOKEN")
        client._init_tenant()
        assert client._tenant_id == "tenant-id"

        # the settings on disk still point at the originally saved tenant
        assert client._load_local_settings()["active_tenant_id"] == saved_tenant
Example #21
0
def test_login_writes_token(monkeypatch):
    """Each login() overwrites the token file with the newest api token."""
    with tempfile.NamedTemporaryFile() as token_file:
        monkeypatch.setattr("prefect.client.Client.local_token_path", token_file.name)
        client = Client()

        for token in ("a", "b"):
            # rewind so read() sees the contents written by this login
            token_file.seek(0)
            client.login(api_token=token)
            assert token_file.read() == token.encode()
Example #22
0
def test_client_logs_out_and_deletes_auth_token(monkeypatch):
    """login() persists the returned token to disk and logout() removes it."""
    login_response = MagicMock(
        ok=True, json=MagicMock(return_value=dict(token="secrettoken")))
    monkeypatch.setattr("requests.post", MagicMock(return_value=login_response))
    with set_temporary_config({"cloud.graphql": "http://my-cloud.foo"}):
        client = Client()

    client.login("*****@*****.**", "1234")

    # NOTE(review): this path is in the real user's home directory, not a temp dir
    token_path = os.path.expanduser("~/.prefect/.credentials/auth_token")
    assert os.path.exists(token_path)
    with open(token_path, "r") as token_file:
        assert token_file.read() == "secrettoken"

    client.logout()
    assert not os.path.exists(token_path)
Example #23
0
def test_client_attached_headers(monkeypatch, cloud_api):
    """attach_headers merges new headers into those already attached."""
    session_cls = MagicMock()
    session_cls.return_value.get = MagicMock()
    monkeypatch.setattr("requests.Session", session_cls)

    with set_temporary_config({"cloud.auth_token": "secret_token"}):
        client = Client()
        assert client._attached_headers == {}

        expected = {}
        for name in ("1", "2"):
            client.attach_headers({name: name})
            expected[name] = name
            assert client._attached_headers == expected
Example #24
0
    def test_client_clears_active_tenant_if_login_fails_on_initialization(
            self, patch_post):
        """If logging into the saved tenant fails at init, the saved tenant is dropped."""
        patch_post({
            "errors": [{
                "message": "",
                "locations": [],
                "path": ["tenant"],
                "extensions": {"code": "UNAUTHENTICATED"},
            }]
        })

        # seed local settings with a token and a random active tenant
        seeder = Client()
        settings = seeder._load_local_settings()
        settings.update(api_token="API_TOKEN", active_tenant_id=str(uuid.uuid4()))
        seeder._save_local_settings(settings)

        # constructing a new client hits the patched UNAUTHENTICATED error
        reloaded = Client()._load_local_settings()
        assert "active_tenant_id" not in reloaded
Example #25
0
 def test_login_uses_api_token_when_access_token_is_set(self, patch_post):
     """login_to_tenant authenticates with the API token even when an access token is cached."""
     tenant_id = str(uuid.uuid4())
     switch_tenant = {
         "access_token": "ACCESS_TOKEN",
         "expires_at": "2100-01-01",
         "refresh_token": "REFRESH_TOKEN",
     }
     post = patch_post(
         {"data": {"tenant": [{"id": tenant_id}], "switch_tenant": switch_tenant}}
     )

     client = Client(api_token="api")
     client._access_token = "access"
     client.login_to_tenant(tenant_id=tenant_id)

     assert client.get_auth_token() == "ACCESS_TOKEN"
     expected_headers = {
         "Authorization": "Bearer api",
         "X-PREFECT-CORE-VERSION": str(prefect.__version__),
     }
     assert post.call_args[1]["headers"] == expected_headers
Example #26
0
def test_get_flow_run_info_with_nontrivial_payloads(patch_post):
    """Nested parameters and context come back as plain, JSON-round-trippable dicts."""
    run_id = "da344768-5f5d-4eaf-9bca-83815617f713"
    version_tag = "0.3.3+309.gf1db024"
    task_run = {
        "id": run_id,
        "task": {"id": run_id, "slug": run_id},
        "version": 0,
        "serialized_state": {
            "type": "Pending",
            "result": None,
            "message": None,
            "__version__": version_tag,
            "cached_inputs": None,
        },
    }
    flow_run = {
        "id": run_id,
        "flow_id": run_id,
        "name": "flow-run-name",
        "version": 0,
        "parameters": {"x": {"deep": {"nested": 5}}},
        "context": {"my_val": "test"},
        "scheduled_start_time": "2019-01-25T19:15:58.632412+00:00",
        "serialized_state": {
            "type": "Pending",
            "_result": {
                "type": "SafeResult",
                "value": "42",
                "result_handler": {"type": "JSONResultHandler"},
            },
            "message": None,
            "__version__": version_tag,
            "cached_inputs": None,
        },
        "task_runs": [task_run],
    }
    patch_post({"data": {"flow_run_by_pk": flow_run}})

    config = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
        "backend": "cloud",
    }
    with set_temporary_config(config):
        client = Client()
    info = client.get_flow_run_info(flow_run_id="74-salt")

    assert isinstance(info, FlowRunInfoResult)
    start = info.scheduled_start_time
    assert isinstance(start, datetime.datetime)
    assert start.minute == 15
    assert start.year == 2019
    assert isinstance(info.state, Pending)
    assert info.state.result == "42"
    assert info.state.message is None
    assert info.version == 0

    params = info.parameters
    assert isinstance(params, dict)
    assert params["x"]["deep"]["nested"] == 5
    # every nested mapping must be a real dict, so a JSON round trip is lossless
    assert json.loads(json.dumps(params)) == params
    assert isinstance(info.context, dict)
    assert info.context["my_val"] == "test"
Example #27
0
def test_client_token_path_depends_on_graphql_server():
    """Each GraphQL server gets its own token file path."""
    expected_paths = {
        "a": "~/.prefect/tokens/a",
        "b": "~/.prefect/tokens/b",
    }
    for server, path in expected_paths.items():
        client = Client(graphql_server=server)
        assert client.local_token_path == os.path.expanduser(path)
Example #28
0
 def test_client_token_priotizes_arg_over_config(self):
     """An explicit api_token argument beats cloud.auth_token from config."""
     config = {"cloud.auth_token": "CONFIG_TOKEN"}
     with set_temporary_config(config):
         client = Client(api_token="ARG_TOKEN")
     assert client._api_token == "ARG_TOKEN"
Example #29
0
def test_cloud_task_runners_submitted_to_remote_machines_respect_original_config(
    monkeypatch, ):
    """
    Simulate running a Cloud flow against an external cluster that has _not_ been
    configured for Prefect. The configuration present on the original machine should
    be respected in the remote job. Here that is checked in two ways: the cloud log
    handler is called during logging, and the special config values are visible in
    context.
    """
    # Flow runner whose tasks execute under a config with cloud logging turned off
    # and an empty auth token, mimicking a remote worker without Prefect config.
    class CustomFlowRunner(CloudFlowRunner):
        def run_task(self, *args, **kwargs):
            with prefect.utilities.configuration.set_temporary_config({
                    "logging.log_to_cloud":
                    False,
                    "cloud.auth_token":
                    ""
            }):
                return super().run_task(*args, **kwargs)

    # Logs one critical message and returns the config values the task sees at run time.
    @prefect.task(result_handler=JSONResultHandler())
    def log_stuff():
        logger = prefect.context.get("logger")
        logger.critical("important log right here")
        return (
            prefect.context.config.special_key,
            prefect.context.config.cloud.auth_token,
        )

    # Positional args of every write_run_logs call made on the fake client.
    calls = []

    # Fake client:
    # - write_run_logs records its args in `calls`;
    # - set_task_run_state echoes back the requested state;
    # - get_flow_run_info reports one task run (id "TESTME") for log_stuff.
    class Client(MagicMock):
        def write_run_logs(self, *args, **kwargs):
            calls.append(args)

        def set_task_run_state(self, *args, **kwargs):
            return kwargs.get("state")

        def get_flow_run_info(self, *args, **kwargs):
            return MagicMock(
                id="flow_run_id",
                task_runs=[MagicMock(task_slug=log_stuff.slug, id="TESTME")],
            )

    # Replace Client at every import site the runners use.
    monkeypatch.setattr("prefect.client.Client", Client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", Client)
    monkeypatch.setattr("prefect.engine.cloud.flow_runner.Client", Client)
    # Presumably the last prefect logger handler is the cloud log handler — TODO confirm.
    # NOTE(review): this is a plain assignment, so monkeypatch does not undo it after the test.
    prefect.utilities.logging.prefect_logger.handlers[-1].client = Client()

    with prefect.utilities.configuration.set_temporary_config({
            "logging.log_to_cloud":
            True,
            "special_key":
            42,
            "cloud.auth_token":
            "original",
    }):
        # captures config at init
        runner = CustomFlowRunner(flow=prefect.Flow("test", tasks=[log_stuff]))
        flow_state = runner.run(
            return_tasks=[log_stuff],
            task_contexts={log_stuff: dict(special_key=99)})

    # The task observes the original config values (42, "original").
    # It does not see the task-context special_key (99) or the auth token
    # that run_task blanks out.
    assert flow_state.is_successful()
    assert flow_state.result[log_stuff].result == (42, "original")

    # Presumably gives the log handler time to flush batched logs — TODO confirm.
    time.sleep(0.75)
    # Each recorded call's first positional argument is an iterable of log dicts.
    logs = [log for call in calls for log in call[0]]
    assert len(logs) >= 6  # actual number of logs

    loggers = [c["name"] for c in logs]
    assert set(loggers) == {
        "prefect.CloudTaskRunner",
        "prefect.CustomFlowRunner",
        "prefect.log_stuff",
    }

    # Logs tied to a task run must reference the id the fake client reported.
    task_run_ids = [c["task_run_id"] for c in logs if c["task_run_id"]]
    assert set(task_run_ids) == {"TESTME"}
Example #30
0
def test_get_flow_run_info(patch_post):
    """get_flow_run_info parses timestamps, state, version, parameters and a null context."""
    run_id = "da344768-5f5d-4eaf-9bca-83815617f713"
    version_tag = "0.3.3+309.gf1db024"
    task_run = {
        "id": run_id,
        "task": {"id": run_id, "slug": run_id},
        "version": 0,
        "serialized_state": {
            "type": "Pending",
            "result": None,
            "message": None,
            "__version__": version_tag,
            "cached_inputs": None,
        },
    }
    flow_run = {
        "id": run_id,
        "flow_id": run_id,
        "name": "flow-run-name",
        "version": 0,
        "parameters": {},
        "context": None,
        "scheduled_start_time": "2019-01-25T19:15:58.632412+00:00",
        "serialized_state": {
            "type": "Pending",
            "_result": {
                "type": "SafeResult",
                "value": "42",
                "result_handler": {"type": "JSONResultHandler"},
            },
            "message": None,
            "__version__": version_tag,
            "cached_inputs": None,
        },
        "task_runs": [task_run],
    }
    patch_post({"data": {"flow_run_by_pk": flow_run}})

    config = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(config):
        client = Client()
    info = client.get_flow_run_info(flow_run_id="74-salt")

    assert isinstance(info, FlowRunInfoResult)
    start = info.scheduled_start_time
    assert isinstance(start, datetime.datetime)
    assert start.minute == 15
    assert start.year == 2019
    assert isinstance(info.state, Pending)
    assert info.state.result == "42"
    assert info.state.message is None
    assert info.version == 0
    assert isinstance(info.parameters, dict)
    assert info.context is None