def test_no_RoleARN(runner, live_mock_server, test_settings,
                    mocked_fetchable_git_repo, monkeypatch):
    """Check that a sagemaker launch without ``RoleArn`` raises ``LaunchError``.

    The AWS credential env vars, ``boto3.client``, the ECR login helper and
    ``wandb.docker.push`` are replaced with stubs, then the sagemaker config
    fixture is loaded and its ``RoleArn`` entry is removed before launching.
    """
    # NOTE: the two literals are joined without a separating space
    # ("...sagemakerfield of..."); the comparison below relies on that exact text.
    expected_message = (
        "AWS sagemaker require a string RoleArn set this by adding a `RoleArn` key to the sagemaker"
        "field of resource_args"
    )

    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test")
    monkeypatch.setattr(boto3, "client", mock_boto3_client)
    monkeypatch.setattr(
        wandb.sdk.launch.runner.aws,
        "aws_ecr_login",
        lambda x, y: "Login Succeeded\n",
    )
    monkeypatch.setattr(
        wandb.docker,
        "push",
        lambda x, y: f"The push refers to repository [{x}]",
    )

    config_text = fixture_open("launch/launch_sagemaker_config.json").read()
    kwargs = json.loads(config_text)

    with runner.isolated_filesystem():
        kwargs["uri"] = "https://wandb.ai/mock_server_entity/test/runs/1"
        kwargs["api"] = wandb.sdk.internal.internal_api.Api(
            default_settings=test_settings, load_settings=False
        )
        # Remove the role (if present) so the launch runs without a RoleArn.
        kwargs["resource_args"]["sagemaker"].pop("RoleArn", None)

        with pytest.raises(wandb.errors.LaunchError) as e_info:
            launch.run(**kwargs)

        assert str(e_info.value) == expected_message
def test_launch_aws_sagemaker_push_image_fail_err_msg(
    live_mock_server,
    test_settings,
    mocked_fetchable_git_repo,
    monkeypatch,
):
    """Check that the output of a failed docker push ends up in the ``LaunchError``.

    ``wandb.docker.push`` is stubbed to return a failure message; the test
    asserts that this message is contained in the raised error's text.
    """
    push_output = "I regret to inform you, that I have failed"

    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test")
    monkeypatch.setattr(boto3, "client", mock_boto3_client)
    monkeypatch.setattr(wandb.docker, "tag", lambda x, y: "")
    monkeypatch.setattr(
        wandb.sdk.launch.runner.aws,
        "aws_ecr_login",
        lambda x, y: "Login Succeeded\n",
    )
    monkeypatch.setattr(wandb.docker, "push", lambda x, y: push_output)

    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    config_text = fixture_open("launch/launch_sagemaker_config.json").read()
    kwargs = json.loads(config_text)
    kwargs["uri"] = "https://wandb.ai/mock_server_entity/test/runs/1"
    kwargs["api"] = api

    with pytest.raises(wandb.errors.LaunchError) as e_info:
        launch.run(**kwargs)

    assert push_output in str(e_info.value)
def test_launch_aws_sagemaker_launch_fail(
    live_mock_server,
    test_settings,
    mocked_fetchable_git_repo,
    monkeypatch,
):
    """Check that a failed training-job creation surfaces as a ``LaunchError``.

    ``boto3.client`` is replaced by a factory whose sagemaker client returns
    an empty dict from ``create_training_job``; the raised error must mention
    "Unable to create training job".
    """

    def mock_client_launch_fail(*args, **kwargs):
        # Stand-in for ``boto3.client``: the first positional argument picks
        # which mocked service client is handed back.
        service = args[0]

        if service == "sagemaker":
            sagemaker = MagicMock()
            # Empty response -- presumably what makes the launch report failure.
            sagemaker.create_training_job.return_value = {}
            sagemaker.stop_training_job.return_value = {
                "TrainingJobArn": "arn:aws:sagemaker:us-east-1:123456789012:TrainingJob/test-job-1"
            }
            sagemaker.describe_training_job.return_value = {
                "TrainingJobStatus": "Completed",
                "TrainingJobName": "test-job-1",
            }
            return sagemaker

        if service == "ecr":
            ecr = MagicMock()
            ecr.get_authorization_token.return_value = {
                "authorizationData": [
                    {
                        "proxyEndpoint": "https://123456789012.dkr.ecr.us-east-1.amazonaws.com",
                    }
                ]
            }
            return ecr

        if service == "sts":
            sts = MagicMock()
            sts.get_caller_identity.return_value = {
                "Account": "123456789012",
            }
            return sts

        # Any other service name yields no client.
        return None

    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test")
    monkeypatch.setattr(boto3, "client", mock_client_launch_fail)
    monkeypatch.setattr(wandb.docker, "tag", lambda x, y: "")
    monkeypatch.setattr(
        wandb.docker,
        "push",
        lambda x, y: f"The push refers to repository [{x}]",
    )
    monkeypatch.setattr(
        wandb.sdk.launch.runner.aws,
        "aws_ecr_login",
        lambda x, y: "Login Succeeded\n",
    )

    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    config_text = fixture_open("launch/launch_sagemaker_config.json").read()
    kwargs = json.loads(config_text)
    kwargs["uri"] = "https://wandb.ai/mock_server_entity/test/runs/1"
    kwargs["api"] = api

    with pytest.raises(wandb.errors.LaunchError) as e_info:
        launch.run(**kwargs)

    assert "Unable to create training job" in str(e_info.value)
def test_launch_no_server_info(live_mock_server, test_settings,
                               mocked_fetchable_git_repo):
    """Check that ``launch.run`` raises ``LaunchError`` when run info is unavailable.

    ``api.get_run_info`` is replaced by a mock that raises ``wandb.CommError``.
    The launch must then fail with a ``LaunchError`` whose message contains
    "Run info is invalid or doesn't exist".
    """
    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    api.get_run_info = MagicMock(
        return_value=None, side_effect=wandb.CommError("test comm error")
    )

    # pytest.raises (rather than try/except) makes the test fail when no
    # LaunchError is raised at all, instead of passing without asserting.
    with pytest.raises(wandb.errors.LaunchError) as e_info:
        launch.run(
            "https://wandb.ai/mock_server_entity/test/runs/1",
            api,
            project="new-test",
        )

    assert "Run info is invalid or doesn't exist" in str(e_info.value)
def test_launch_args_supersede_config_vals(
    live_mock_server, test_settings, mocked_fetchable_git_repo, mock_load_backend
):
    """Check that explicit launch arguments win over values given in ``config``.

    ``project`` and ``parameters`` are passed directly, while ``config``
    carries a different project plus ``run_config``/``args`` overrides. Every
    ``LaunchProject`` found in ``mock_with_run_info.args`` must reflect the
    directly passed values.
    """
    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    kwargs = {
        "uri": "https://wandb.ai/mock_server_entity/test/runs/1",
        "api": api,
        "project": "new_test_project",
        "entity": "mock_server_entity",
        "config": {
            "project": "not-this-project",
            "overrides": {
                "run_config": {"epochs": 3},
                "args": ["--epochs=2", "--heavy"],
            },
        },
        "parameters": {"epochs": 5},
    }

    mock_with_run_info = launch.run(**kwargs)

    launch_projects = [
        arg
        for arg in mock_with_run_info.args
        if isinstance(arg, _project_spec.LaunchProject)
    ]
    # Without this guard the assertions below would be skipped silently when
    # no LaunchProject is present, and the test would pass without checking.
    assert launch_projects, "no LaunchProject found in mock_with_run_info.args"

    for launch_project in launch_projects:
        assert launch_project.override_args["epochs"] == 5
        assert launch_project.override_config.get("epochs") is None
        assert launch_project.target_project == "new_test_project"
def test_launch_base_case(
    live_mock_server,
    test_settings,
    mocked_fetchable_git_repo,
    mock_load_backend,
    monkeypatch,
):
    """Launch a run with default options and check the generated dockerfile env.

    ``create_metadata_file`` is wrapped so that the dockerfile contents it is
    passed (fifth positional argument) are checked for the expected ``ENV``
    lines before the real implementation is called.
    """
    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )

    def mock_create_metadata_file(*args, **kwargs):
        dockerfile = args[4]
        required_env_lines = (
            "ENV WANDB_BASE_URL=https://api.wandb.ai",
            f"ENV WANDB_API_KEY={api.api_key}",
            "ENV WANDB_PROJECT=test",
            "ENV WANDB_ENTITY=mock_server_entity",
        )
        for env_line in required_env_lines:
            assert env_line in dockerfile
        # Delegate to the real implementation once the checks pass.
        _project_spec.create_metadata_file(*args, **kwargs)

    monkeypatch.setattr(
        wandb.sdk.launch.docker,
        "create_metadata_file",
        mock_create_metadata_file,
    )

    kwargs = {
        "uri": "https://wandb.ai/mock_server_entity/test/runs/1",
        "api": api,
        "entity": "mock_server_entity",
        "project": "test",
    }
    expected_config = {}

    mock_with_run_info = launch.run(**kwargs)
    check_mock_run_info(mock_with_run_info, expected_config, kwargs)
def test_launch_aws_sagemaker(
    live_mock_server,
    test_settings,
    mocked_fetchable_git_repo,
    monkeypatch,
):
    """Launch on sagemaker with AWS/docker stubbed and check the resulting job name.

    ``create_metadata_file`` is wrapped to assert that the dockerfile contents
    contain the ``ENTRYPOINT ["sh", "train"]`` line; the returned run must
    report ``test-job-1`` as its training job name.
    """

    def mock_create_metadata_file(*args, **kwargs):
        dockerfile = args[4]
        # On failure the dockerfile itself is used as the assertion message.
        assert 'ENTRYPOINT ["sh", "train"]' in dockerfile, dockerfile
        _project_spec.create_metadata_file(*args, **kwargs)

    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test")
    monkeypatch.setattr(boto3, "client", mock_boto3_client)
    monkeypatch.setattr(
        wandb.sdk.launch.docker,
        "create_metadata_file",
        mock_create_metadata_file,
    )
    monkeypatch.setattr(wandb.docker, "tag", lambda x, y: "")
    monkeypatch.setattr(
        wandb.docker,
        "push",
        lambda x, y: f"The push refers to repository [{x}]",
    )
    monkeypatch.setattr(
        wandb.sdk.launch.runner.aws,
        "aws_ecr_login",
        lambda x, y: "Login Succeeded\n",
    )

    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    config_text = fixture_open("launch/launch_sagemaker_config.json").read()
    kwargs = json.loads(config_text)
    kwargs["uri"] = "https://wandb.ai/mock_server_entity/test/runs/1"
    kwargs["api"] = api

    launched_run = launch.run(**kwargs)
    assert launched_run.training_job_name == "test-job-1"
def test_launch_code_artifact(runner, live_mock_server, test_settings,
                              monkeypatch, mock_load_backend):
    """Launch a run whose code comes from a mocked logged ``code`` artifact.

    ``wandb.PublicApi.run`` is patched to return a mock run with a single
    artifact of type ``"code"``; its ``download`` writes ``train.py``,
    ``requirements.txt`` and ``patch.txt`` into the destination directory.
    """

    def download_func(dst_dir):
        # Copy the two fixture files verbatim under their own names.
        for fixture_name in ("train.py", "requirements.txt"):
            with open(os.path.join(dst_dir, fixture_name), "w") as out_file:
                out_file.write(fixture_open(fixture_name).read())
        with open(os.path.join(dst_dir, "patch.txt"), "w") as out_file:
            out_file.write("testing")

    code_artifact = mock.MagicMock()
    code_artifact.type = "code"
    code_artifact.download = download_func

    run_with_artifacts = mock.MagicMock()
    run_with_artifacts.logged_artifacts.return_value = [code_artifact]
    monkeypatch.setattr(
        wandb.PublicApi, "run", lambda *arg, **kwargs: run_with_artifacts
    )

    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    kwargs = {
        "uri": "https://wandb.ai/mock_server_entity/test/runs/1",
        "api": api,
        "entity": "mock_server_entity",
        "project": "test",
    }
    expected_config = {}

    mock_with_run_info = launch.run(**kwargs)
    check_mock_run_info(mock_with_run_info, expected_config, kwargs)
def test_failed_aws_cred_login(runner, live_mock_server, monkeypatch,
                               test_settings, mocked_fetchable_git_repo):
    """Check that a failing ECR login makes the sagemaker launch raise ``LaunchError``.

    The ECR login helper is stubbed to return ``"Login Failed\\n"``.
    """
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test")
    monkeypatch.setattr(boto3, "client", mock_boto3_client)
    monkeypatch.setattr(
        wandb.sdk.launch.runner.aws,
        "aws_ecr_login",
        lambda x, y: "Login Failed\n",
    )

    config_text = fixture_open("launch/launch_sagemaker_config.json").read()
    kwargs = json.loads(config_text)

    with runner.isolated_filesystem():
        kwargs["uri"] = "https://wandb.ai/mock_server_entity/test/runs/1"
        kwargs["api"] = wandb.sdk.internal.internal_api.Api(
            default_settings=test_settings, load_settings=False
        )
        with pytest.raises(wandb.errors.LaunchError):
            launch.run(**kwargs)
def test_launch_notebook(live_mock_server, test_settings,
                         mocked_fetchable_git_repo_ipython):
    """Launch the ``jupyter1`` run and check that it ends with status ``finished``.

    The mock server context is set to return jupyter data in the run info.
    """
    live_mock_server.set_ctx({"return_jupyter_in_run_info": True})
    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )

    launched_run = launch.run(
        "https://wandb.ai/mock_server_entity/test/runs/jupyter1",
        api,
        project="new-test",
    )

    status = launched_run.get_status()
    assert str(status) == "finished"
def test_launch_full_build_new_image(live_mock_server, test_settings,
                                     mocked_fetchable_git_repo):
    """Launch into a project with a freshly generated name and expect ``finished``.

    The project name is ``new-test-<id>`` where the id comes from
    ``util.generate_id()``.
    """
    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    project_name = f"new-test-{util.generate_id()}"

    launched_run = launch.run(
        "https://wandb.ai/mock_server_entity/test/runs/1",
        api,
        project=project_name,
    )

    status = launched_run.get_status()
    assert str(status) == "finished"
def test_launch_unowned_project(live_mock_server, test_settings,
                                mocked_fetchable_git_repo, mock_load_backend):
    """Launch a run URI under ``other_user`` into ``mock_server_entity``'s project.

    The result is validated with ``check_mock_run_info`` against an empty
    expected config.
    """
    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    kwargs = dict(
        uri="https://wandb.ai/other_user/test_project/runs/1",
        api=api,
        project="new_test_project",
        entity="mock_server_entity",
    )
    expected_config = {}

    mock_with_run_info = launch.run(**kwargs)
    check_mock_run_info(mock_with_run_info, expected_config, kwargs)
def test_bare_wandb_uri(live_mock_server, test_settings,
                        mocked_fetchable_git_repo, mock_load_backend):
    """Launch from a URI given without scheme/host and verify the run info.

    The launch receives the bare path; for verification the ``uri`` entry is
    replaced by the mock server's base URL followed by that path.
    """
    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    bare_uri = "/mock_server_entity/test/runs/12345678"
    kwargs = {
        "uri": bare_uri,
        "api": api,
        "entity": "mock_server_entity",
        "project": "test",
    }
    expected_config = {}

    mock_with_run_info = launch.run(**kwargs)

    # check_mock_run_info is given the fully qualified form of the URI.
    kwargs["uri"] = live_mock_server.base_url + bare_uri
    check_mock_run_info(mock_with_run_info, expected_config, kwargs)
def test_launch_metadata(live_mock_server, test_settings,
                         mocked_fetchable_git_repo):
    """Launch a run whose metadata, code and requirements come from mocked downloads.

    ``api.download_url`` and ``api.download_file`` are replaced by mocks that
    serve a ``codePath`` of ``main2.py``, a three-line ``main2.py`` and a
    one-line requirements file; the launched run must end as ``finished``.
    """
    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )

    class MockedFileResponder:
        """Minimal response object keyed on the fake URL it was created for."""

        def __init__(self, url):
            self.url: str = url

        def json(self):
            if self.url == "urlForCodePath":
                return {"codePath": "main2.py"}

        def iter_content(self, chunk_size):
            if self.url == "requirements":
                return [b"numpy==1.19.5\n"]
            elif self.url == "main2.py":
                return [
                    b"import wandb\n",
                    b"import numpy\n",
                    b"print('ran server fetched code')\n",
                ]

    # for now using mocks instead of mock server
    def mocked_download_url(*args, **kwargs):
        # The second positional argument is the requested file name.
        requested = args[1]
        if requested == "wandb-metadata.json":
            return {"url": "urlForCodePath"}
        if requested == "code/main2.py":
            return {"url": "main2.py"}
        if requested == "requirements.txt":
            return {"url": "requirements"}
        return None

    def mocked_file_download_request(url):
        # Returned as a (status, response) pair.
        return 200, MockedFileResponder(url)

    api.download_url = MagicMock(side_effect=mocked_download_url)
    api.download_file = MagicMock(side_effect=mocked_file_download_request)

    launched_run = launch.run(
        "https://wandb.ai/mock_server_entity/test/runs/1",
        api,
        project="test-another-new-project",
    )

    assert str(launched_run.get_status()) == "finished"
def test_launch_resource_args(live_mock_server, test_settings,
                              mocked_fetchable_git_repo, mock_load_backend):
    """Launch with ``resource="local"`` plus ``resource_args`` and verify the run info."""
    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    resource_args = {"a": "b", "c": "d"}
    kwargs = {
        "uri": "https://wandb.ai/mock_server_entity/test/runs/1",
        "api": api,
        "entity": "mock_server_entity",
        "project": "test",
        "resource": "local",
        "resource_args": resource_args,
    }
    expected_config = {}

    mock_with_run_info = launch.run(**kwargs)
    check_mock_run_info(mock_with_run_info, expected_config, kwargs)
def test_launch_run_config_in_spec(live_mock_server, test_settings,
                                   mocked_fetchable_git_repo, mock_load_backend):
    """Launch with a ``run_config`` override in ``config`` and verify the run info.

    The expected runner config passed to ``check_mock_run_info`` is empty.
    """
    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    launch_config = {"overrides": {"run_config": {"epochs": 3}}}
    kwargs = {
        "uri": "https://wandb.ai/mock_server_entity/test/runs/1",
        "api": api,
        "project": "new_test_project",
        "entity": "mock_server_entity",
        "config": launch_config,
    }
    expected_runner_config = {}

    mock_with_run_info = launch.run(**kwargs)
    check_mock_run_info(mock_with_run_info, expected_runner_config, kwargs)
def test_sagemaker_specified_image(live_mock_server, test_settings,
                                   mocked_fetchable_git_repo, monkeypatch,
                                   capsys):
    """Launch on sagemaker with a user-supplied ``TrainingImage``.

    The test expects a notice about the user-provided ECR image on stderr and
    ``test-job-1`` as the training job name of the returned run.
    """
    expected_notice = (
        "Launching sagemaker job with user provided ECR image, this image will not be able to swap artifacts"
    )

    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test")
    monkeypatch.setattr(boto3, "client", mock_boto3_client)

    api = wandb.sdk.internal.internal_api.Api(
        default_settings=test_settings, load_settings=False
    )
    config_text = fixture_open("launch/launch_sagemaker_config.json").read()
    kwargs = json.loads(config_text)
    kwargs["uri"] = "https://wandb.ai/mock_server_entity/test/runs/1"
    kwargs["api"] = api

    # Set an explicit image and input mode on the fixture's algorithm spec.
    algorithm_spec = kwargs["resource_args"]["sagemaker"]["AlgorithmSpecification"]
    algorithm_spec["TrainingImage"] = "my-test_image"
    algorithm_spec["TrainingInputMode"] = "File"

    launched_run = launch.run(**kwargs)

    captured = capsys.readouterr()
    assert expected_notice in captured.err
    assert launched_run.training_job_name == "test-job-1"