def test_custom_tls_certs() -> None:
    """Check how each ``certs.Cert`` configuration treats an untrusted TLS server.

    A local test server is started with ``UNTRUSTED_CERT``/``UNTRUSTED_KEY``.
    For every case, a request to ``TRUSTED_DOMAIN`` must succeed. A request to
    the untrusted server must raise ``requests.exceptions.SSLError`` exactly
    when the case is marked as expecting an error.
    """
    server_addr = ("127.0.0.1", 12345)
    with run_test_server(server_addr, cert=UNTRUSTED_CERT, key=UNTRUSTED_KEY) as untrusted_url:
        with open(UNTRUSTED_CERT) as pem_file:
            untrusted_pem = pem_file.read()

        # Each case: (keyword arguments for certs.Cert, expect SSLError from the untrusted server?)
        cases = [
            ({"noverify": True}, False),
            ({"noverify": False}, True),
            ({"cert_pem": untrusted_pem}, False),
            ({}, True),
        ]
        for cert_kwargs, expect_ssl_error in cases:
            # Narrows the element type for mypy before ** unpacking.
            assert isinstance(cert_kwargs, dict)
            cert = certs.Cert(**cert_kwargs)

            # Trusted domains should always work.
            request.get(TRUSTED_DOMAIN, "", authenticated=False, cert=cert)

            with contextlib.ExitStack() as stack:
                if expect_ssl_error:
                    stack.enter_context(pytest.raises(requests.exceptions.SSLError))
                request.get(untrusted_url, "", authenticated=False, cert=cert)
def test_custom_tls_certs() -> None:
    """Check each master cert bundle setting against a trusted and an untrusted domain.

    For every bundle value passed to ``request.set_master_cert_bundle``, a
    request to ``TRUSTED_DOMAIN`` must succeed. A request to
    ``UNTRUSTED_DOMAIN`` must raise ``requests.exceptions.SSLError`` exactly
    when the case is marked as expecting an error.
    """
    # NOTE(review): this test shares its name with another
    # ``test_custom_tls_certs``; if both are defined in the same module, only the
    # later definition is collected by pytest.

    # Each case: (bundle for set_master_cert_bundle, expect SSLError from UNTRUSTED_DOMAIN?)
    cases = [
        (False, False),
        (True, True),
        (UNTRUSTED_CERT_FILE, False),
        (None, True),
    ]
    for cert_bundle, expect_ssl_error in cases:
        request.set_master_cert_bundle(cert_bundle)  # type: ignore

        # The trusted domain must be reachable whatever the bundle setting is.
        request.get(TRUSTED_DOMAIN, "", authenticated=False)

        with contextlib.ExitStack() as stack:
            if expect_ssl_error:
                stack.enter_context(pytest.raises(requests.exceptions.SSLError))
            request.get(UNTRUSTED_DOMAIN, "", authenticated=False)
def trial_prep(info: det.ClusterInfo, cert: certs.Cert) -> None:
    """Download this trial's model definition from the master and unpack it.

    Loads the trial info from the environment and persists it via
    ``TrialInfo._to_file()``. Then fetches
    ``api/v1/experiments/<id>/model_def`` from the master, base64-decodes the
    ``b64Tgz`` field, and extracts the resulting gzipped tarball twice: into
    ``constants.MANAGED_TRAINING_MODEL_COPY`` and into the current directory.

    Args:
        info: Cluster information. Supplies the master URL and the TLS
            settings that are printed in the diagnostic message on failure.
        cert: TLS certificate configuration passed to ``request.get``.

    Raises:
        Exception: Whatever ``request.get`` or ``raise_for_status`` raised is
            re-raised after a diagnostic message is printed.
        ValueError: If any tarball member would extract outside the target
            directory.
    """
    trial_info = det.TrialInfo._from_env()
    trial_info._to_file()

    path = f"api/v1/experiments/{trial_info.experiment_id}/model_def"
    resp = None
    try:
        resp = request.get(info.master_url, path=path, cert=cert)
        resp.raise_for_status()
    except Exception:
        # Since this is the very first api call in the entrypoint script, and the call is made
        # before you can debug with a startup hook, we offer an overly-detailed explanation to help
        # sysadmins debug their cluster.

        # Compare against None explicitly: a requests.Response (assumed to be what
        # request.get returns) is falsy for 4xx/5xx statuses. ``resp and resp.content``
        # would therefore print the Response repr instead of the body, exactly in the
        # HTTP-error case this diagnostic exists for.
        resp_content = str(resp.content if resp is not None else None)
        resp_code = resp.status_code if resp is not None else None

        noverify = info.master_cert_file == "noverify"
        cert_content = None if noverify else info.master_cert_file
        if cert_content is not None:
            try:
                with open(cert_content) as f:
                    cert_content = f.read()
            except OSError as read_err:
                # Don't let a failure in this diagnostic path replace the original
                # error that the bare ``raise`` below is meant to propagate.
                cert_content = f"<unable to read {cert_content}: {read_err}>"
        print(
            "Failed to download model definition from master.  This may be due to an address\n"
            "resolution problem, a certificate problem, a firewall problem, or some other\n"
            "networking error.\n"
            "Debug information:\n"
            f"    master_url: {info.master_url}\n"
            f"    endpoint: {path}\n"
            f"    tls_verify_name: {info.master_cert_name}\n"
            f"    tls_noverify: {noverify}\n"
            f"    tls_cert: {cert_content}\n"
            f"    response code: {resp_code}\n"
            f"    response content: {resp_content}\n"
        )
        raise

    tgz = base64.b64decode(resp.json()["b64Tgz"])

    with tarfile.open(fileobj=io.BytesIO(tgz), mode="r:gz") as model_def:
        # Ensure all members of the tarball resolve to subdirectories.
        parent_prefix = os.pardir + os.sep
        for member_name in model_def.getnames():
            rel = os.path.relpath(member_name)
            # relpath normalizes e.g. "a/../.." to a bare "..", which has no trailing
            # separator, so check for the bare parent name as well as the prefix.
            if rel == os.pardir or rel.startswith(parent_prefix):
                raise ValueError(
                    f"'{member_name}' in tarball would expand to a parent directory"
                )
        model_def.extractall(path=constants.MANAGED_TRAINING_MODEL_COPY)
        model_def.extractall(path=".")
"""Download an experiment's model definition from the master and unpack it.

Reads ``DET_EXPERIMENT_ID``, ``DET_MASTER_ADDR``, ``DET_MASTER_PORT`` and the
optional ``DET_USE_TLS`` (default ``"false"``) from the environment. Fetches
``api/v1/experiments/<id>/model_def``, base64-decodes its ``b64Tgz`` field, and
extracts the gzipped tarball into ``constants.MANAGED_TRAINING_MODEL_COPY`` and
into the current directory.
"""
import base64
import distutils.util
import io
import os
import tarfile

from determined import constants
from determined.common.api import certs, request

if __name__ == "__main__":
    exp_id = os.environ["DET_EXPERIMENT_ID"]
    master_addr = os.environ["DET_MASTER_ADDR"]
    master_port = os.environ["DET_MASTER_PORT"]
    # NOTE(review): distutils is deprecated (PEP 632) and removed in Python 3.12;
    # strtobool will need a local replacement before moving to 3.12+.
    use_tls = distutils.util.strtobool(os.environ.get("DET_USE_TLS", "false"))
    master_url = f"http{'s' if use_tls else ''}://{master_addr}:{master_port}"

    # Assigned to the module-level certs.cli_cert; presumably request.get falls back
    # to it since no cert is passed below -- verify against determined.common.api.
    certs.cli_cert = certs.default_load(master_url=master_url)

    resp = request.get(master_url, f"api/v1/experiments/{exp_id}/model_def")
    resp.raise_for_status()

    # base64 was previously used without being imported, so this line raised NameError.
    tgz = base64.b64decode(resp.json()["b64Tgz"])

    with tarfile.open(fileobj=io.BytesIO(tgz), mode="r:gz") as model_def:
        # Ensure all members of the tarball resolve to subdirectories.
        parent_prefix = os.pardir + os.sep
        for member_name in model_def.getnames():
            rel = os.path.relpath(member_name)
            # relpath normalizes e.g. "a/../.." to a bare "..", which has no trailing
            # separator, so check for the bare parent name as well as the prefix.
            if rel == os.pardir or rel.startswith(parent_prefix):
                raise ValueError(
                    f"'{member_name}' in tarball would expand to a parent directory"
                )
        model_def.extractall(path=constants.MANAGED_TRAINING_MODEL_COPY)
        model_def.extractall(path=".")