def test_inference_job_scale(args):
    if utils.get_launcher(args.config) == "controller":
        return
    job_spec = utils.gen_default_job_description("inference", args.email,
                                                 args.uid, args.vc,
                                                 cmd="sleep 600")

    with utils.run_job(args.rest, job_spec) as job:
        job_id = job.jid
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        deployment_name = job_id + "-deployment"
        deployment = utils.kube_get_deployment(args.config, "default", deployment_name)
        assert 1 == deployment.spec.replicas

        desired_replicas = 2
        logger.info("scale up job %s to %d" % (job_id, desired_replicas))
        resp = utils.scale_job(args.rest, args.email, job_id, desired_replicas)
        assert "Success" == resp

        time.sleep(30)
        deployment = utils.kube_get_deployment(args.config, "default", deployment_name)
        assert desired_replicas == deployment.spec.replicas

        desired_replicas = 1
        logger.info("scale down job %s to %d" % (job_id, desired_replicas))
        resp = utils.scale_job(args.rest, args.email, job_id, desired_replicas)
        assert "Success" == resp

        time.sleep(30)
        deployment = utils.kube_get_deployment(args.config, "default", deployment_name)
        assert desired_replicas == deployment.spec.replicas
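# Hedged sketch: the fixed time.sleep(30) waits above can be flaky on slow
# clusters. A polling helper like this one (hypothetical, not part of utils)
# would wait until the deployment converges on the desired replica count.
def wait_for_replicas(config, namespace, name, want, timeout=60, interval=2):
    deadline = time.time() + timeout
    while time.time() < deadline:
        deployment = utils.kube_get_deployment(config, namespace, name)
        if deployment.spec.replicas == want:
            return deployment
        time.sleep(interval)
    raise AssertionError("%s did not reach %d replicas in %ds" %
                         (name, want, timeout))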
def test_op_job(args):
    job_spec = utils.gen_default_job_description("regular", args.email,
                                                 args.uid, args.vc)

    with utils.run_job(args.rest, job_spec) as job:
        job_id = job.jid
        utils.block_until_state_in(args.rest, job_id, {"running"})

        # Try to ApproveJob
        logger.info("approve job %s" % job_id)
        resp = utils.approve_job(args.rest, args.email, job_id)
        assert "Cannot approve the job. Job ID:%s" % job_id == resp["result"]

        # PauseJob
        logger.info("pause job %s" % job_id)
        resp = utils.pause_job(args.rest, args.email, job_id)
        assert "Success, the job is scheduled to be paused." == resp["result"]

        # ResumeJob
        utils.block_until_state_in(args.rest, job_id, {"paused"})
        logger.info("resume job %s" % job_id)
        resp = utils.resume_job(args.rest, args.email, job_id)
        assert "Success, the job is scheduled to be resumed." == resp["result"]

        # KillJob
        utils.block_until_state_in(args.rest, job_id, {"running"})
        logger.info("kill job %s" % job_id)
        resp = utils.kill_job(args.rest, args.email, job_id)
        assert "Success, the job is scheduled to be terminated." == resp[
            "result"]

        state = job.block_until_state_not_in({"killing"})
        assert "killed" == state
def test_inference_job_use_alias_to_run(args):
    job_spec = utils.gen_default_job_description(
        "inference",
        args.email,
        args.uid,
        args.vc,
        cmd="echo dummy `whoami` ; sleep 120")

    def satisfied(expected, times, log):
        """Return True if `expected` occurs at least `times` times in `log`."""
        start = 0
        for _ in range(times):
            end = log.find(expected, start)
            if end == -1:
                return False
            start = end + 1
        return True

    expected_word = "dummy %s" % (args.email.split("@")[0])

    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        for _ in range(300):
            log = utils.get_job_log(args.rest, args.email, job.jid)
            if satisfied(expected_word, 2, log):
                break
            time.sleep(0.5)

        assert satisfied(expected_word, 2, log), "log is %s" % log
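# Hedged simplification: satisfied() above amounts to counting occurrences.
# For a needle like "dummy <alias>" that cannot overlap itself, str.count
# (which counts non-overlapping matches) gives the same answer.
def satisfied_via_count(expected, times, log):
    return log.count(expected) >= times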
def test_blobfuse(args):
    path = "/tmp/blob/${DLTS_JOB_ID}"
    cmd = "echo dummy > %s; cat %s ; rm %s ;" % (path, path, path)

    job_spec = utils.gen_default_job_description("regular",
                                                 args.email,
                                                 args.uid,
                                                 args.vc,
                                                 cmd=cmd)

    job_spec["plugins"] = utils.load_azure_blob_config(args.config,
                                                       "/tmp/blob")

    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling", "running"})
        assert state == "finished", "state is not finished, but %s" % state

        for _ in range(5):
            log = utils.get_job_log(args.rest, args.email, job.jid)
            if log.find("dummy") != -1:
                break
            time.sleep(0.5)

        assert log.find("dummy") != -1, "could not find dummy in log %s" % (
            log)
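# Assumed shape of the blobfuse plugin payload returned by
# utils.load_azure_blob_config (field names are an assumption, shown for
# illustration only):
#
#   job_spec["plugins"] = {
#       "blobfuse": [{
#           "accountName": "...",
#           "accountKey": "...",
#           "containerName": "...",
#           "mountPath": "/tmp/blob",
#       }]
#   }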
def test_inference_job_running(args):
    envs = {
        "DLWS_HOST_NETWORK": "",
        "DLTS_HOST_NETWORK": "",
        "DLWS_NUM_GPU_PER_WORKER": "1",
        "DLTS_NUM_GPU_PER_WORKER": "1",
        "DLWS_VC_NAME": str(args.vc),
        "DLTS_VC_NAME": str(args.vc),
        "DLWS_UID": str(args.uid),
        "DLTS_UID": str(args.uid),
        "DLWS_USER_NAME": args.email.split("@")[0],
        "DLTS_USER_NAME": args.email.split("@")[0],
        "DLWS_USER_EMAIL": args.email,
        "DLTS_USER_EMAIL": args.email,
        "DLWS_ROLE_NAME": "master",
        "DLTS_ROLE_NAME": "master",
        "DLWS_JOB_ID": "unknown",
        "DLTS_JOB_ID": "unknown",
    }

    job_spec = utils.gen_default_job_description("inference", args.email,
                                                 args.uid, args.vc)

    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        envs["DLWS_JOB_ID"] = job.jid
        envs["DLTS_JOB_ID"] = job.jid

        pods = utils.kube_get_pods(args.config, "default", "jobId=" + job.jid)
        assert len(pods) == 2

        for pod in pods:
            envs["DLWS_ROLE_NAME"] = pod.metadata.labels["jobRole"]
            envs["DLTS_ROLE_NAME"] = pod.metadata.labels["jobRole"]
            pod_name = pod.metadata.name
            container_name = pod.spec.containers[0].name

            cmd = ["bash", "-c"]

            remain_cmd = [
                "printf '%s=' ; printenv %s" % (key, key)
                for key in envs
            ]

            cmd.append(";".join(remain_cmd))

            code, output = utils.kube_pod_exec(args.config, "default", pod_name,
                                               container_name, cmd)

            logger.debug("cmd %s output for %s.%s is %s", cmd, pod_name,
                         container_name, output)

            for key, val in envs.items():
                expected_output = "%s=%s" % (key, val)
                assert output.find(
                    expected_output) != -1, "could not find %s in log %s" % (
                        expected_output, output)
def test_data_job_running(args):
    expected_state = "finished"
    expected_word = "wantThisInLog"
    cmd = "mkdir -p /tmp/dlts_test_dir; " \
          "echo %s > /tmp/dlts_test_dir/testfile; " \
          "cd /DataUtils; " \
          "./copy_data.sh /tmp/dlts_test_dir adl://indexserveplatform-experiment-c09.azuredatalakestore.net/local/dlts_test_dir True 4194304 4 2 >/dev/null 2>&1;" \
          "./copy_data.sh adl://indexserveplatform-experiment-c09.azuredatalakestore.net/local/dlts_test_dir /tmp/dlts_test_dir_copyback False 33554432 4 2 >/dev/null 2>&1;" \
          "cat /tmp/dlts_test_dir_copyback/testfile; " % expected_word

    image = "indexserveregistry.azurecr.io/dlts-data-transfer-image:latest"

    job_spec = utils.gen_default_job_description("data",
                                                 args.email,
                                                 args.uid,
                                                 args.vc,
                                                 cmd=cmd,
                                                 image=image)
    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling", "running"})
        assert expected_state == state

        for _ in range(10):
            log = utils.get_job_log(args.rest, args.email, job.jid)
            if expected_word in log:
                break
            time.sleep(0.5)
        assert expected_word in log, 'assert {} in {}'.format(
            expected_word, log)
def test_no_resource_info(args):
    expected = "Insufficient nvidia.com/gpu"

    job_spec = utils.gen_default_job_description("regular",
                                                 args.email,
                                                 args.uid,
                                                 args.vc,
                                                 resourcegpu=5)
    # TODO: 5 is hardcoded here; change it to gpu_per_host + 1 manually when
    # testing other clusters.

    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in({"unapproved", "queued"})
        assert state == "scheduling"

        for _ in range(50):
            details = utils.get_job_detail(args.rest, args.email, job.jid)

            message = utils.walk_json_safe(details, "jobStatusDetail", 0,
                                           "message")
            if message is not None and expected in message:
                break

            time.sleep(0.5)
        assert message is not None and expected in message, \
            "unexpected detail %s" % (details,)
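# Hedged sketch of what utils.walk_json_safe is assumed to do: follow nested
# dict/list keys and return None instead of raising when any step is missing.
def walk_json_safe_sketch(obj, *path):
    for key in path:
        try:
            obj = obj[key]
        except (KeyError, IndexError, TypeError):
            return None
    return obj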
def test_distributed_job_ssh(args):
    job_spec = utils.gen_default_job_description("distributed", args.email,
                                                 args.uid, args.vc)
    with utils.run_job(args.rest, job_spec) as job:
        endpoints = utils.create_endpoint(args.rest, args.email, job.jid,
                                          ["ssh"])
        endpoints_ids = list(endpoints.keys())
        assert len(endpoints_ids) == 2

        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        for endpoint_id in endpoints_ids:
            ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                     job.jid, endpoint_id)
            logger.debug("endpoint_id is %s, endpoints resp is %s",
                         endpoint_id, ssh_endpoint)

            ssh_host = "%s.%s" % (ssh_endpoint["nodeName"],
                                  ssh_endpoint["domain"])
            ssh_port = ssh_endpoint["port"]

            # exec into jobmanager to execute ssh to avoid firewall
            job_manager_pod = utils.kube_get_pods(args.config, "default",
                                                  "app=jobmanager")[0]
            job_manager_pod_name = job_manager_pod.metadata.name

            alias = args.email.split("@")[0]

            cmd_prefix = [
                "ssh",
                "-i",
                "/dlwsdata/work/%s/.ssh/id_rsa" % alias,
                "-p",
                str(ssh_port),
                "-o",
                "StrictHostKeyChecking=no",
                "-o",
                "LogLevel=ERROR",
                "%s@%s" % (alias, ssh_host),
                "--",
            ]

            # check they can connect to each other
            for role in ["ps-0", "worker-0"]:
                cmd = copy.deepcopy(cmd_prefix)
                cmd.extend([
                    "ssh", role, "-o", "LogLevel=ERROR", "--", "echo", "dummy"
                ])
                code, output = utils.kube_pod_exec(args.config, "default",
                                                   job_manager_pod_name,
                                                   "jobmanager", cmd)
                logger.debug("code %s, output '%s'", code, output)
                assert code == 0
                assert output == "dummy\n"
def test_job_fail(args):
    expected_state = "failed"
    cmd = "false"

    job_spec = utils.gen_default_job_description("regular",
                                                 args.email,
                                                 args.uid,
                                                 args.vc,
                                                 cmd=cmd)
    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling", "running"})
        assert expected_state == state
def test_ssh_do_not_expose_private_key(args):
    job_spec = utils.gen_default_job_description("regular", args.email,
                                                 args.uid, args.vc)

    with utils.run_job(args.rest, job_spec) as job:
        endpoints = utils.create_endpoint(args.rest, args.email, job.jid,
                                          ["ssh"])
        endpoints_ids = list(endpoints.keys())
        assert len(endpoints_ids) == 1
        endpoint_id = endpoints_ids[0]

        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                 job.jid, endpoint_id)

        ssh_host = "%s.%s" % (ssh_endpoint["nodeName"], ssh_endpoint["domain"])
        ssh_port = ssh_endpoint["port"]

        # exec into jobmanager to execute ssh to avoid firewall
        job_manager_pod = utils.kube_get_pods(args.config, "default",
                                              "app=jobmanager")[0]
        job_manager_pod_name = job_manager_pod.metadata.name

        alias = args.email.split("@")[0]

        ssh_cmd = [
            "ssh",
            "-i",
            "/dlwsdata/work/%s/.ssh/id_rsa" % alias,
            "-p",
            str(ssh_port),
            "-o",
            "StrictHostKeyChecking=no",
            "-o",
            "LogLevel=ERROR",
            "%s@%s" % (alias, ssh_host),
            "--",
            "echo a ; printenv DLTS_SSH_PRIVATE_KEY ; echo b",
        ]
        code, output = utils.kube_pod_exec(args.config, "default",
                                           job_manager_pod_name, "jobmanager",
                                           ssh_cmd)
        assert code == 0, "code is %s, output is %s" % (code, output)

        expected = "a\nb"
        assert expected in output, "could not find %s in output %s" % (
            expected, output)
def test_gpu_type_override(args):
    job_spec = utils.gen_default_job_description("regular", args.email,
                                                 args.uid, args.vc)
    # wrong gpu type
    job_spec["gpuType"] = "V100"

    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in({"unapproved", "queued"})
        assert state in ["scheduling", "running"]

        pod = utils.kube_get_pods(args.config, "default",
                                  "jobId=%s" % job.jid)[0]

        # gpu type should be overridden with the correct one
        assert pod.metadata.labels.get("gpuType") == "P40"
def test_job_insight(args):
    job_spec = utils.gen_default_job_description("regular", args.email,
                                                 args.uid, args.vc)

    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        payload = {"messages": ["dummy"]}
        resp = utils.set_job_insight(args.rest, args.email, job.jid, payload)
        assert resp.status_code == 200

        insight = utils.get_job_insight(args.rest, args.email, job.jid)
        assert payload == insight
def test_ssh_cuda_visible_devices(args, job_spec, expected):
    with utils.run_job(args.rest, job_spec) as job:
        endpoints = utils.create_endpoint(args.rest, args.email, job.jid,
                                          ["ssh"])
        endpoints_ids = list(endpoints.keys())
        assert len(endpoints_ids) == 1
        endpoint_id = endpoints_ids[0]

        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                 job.jid, endpoint_id)

        ssh_host = "%s.%s" % (ssh_endpoint["nodeName"], ssh_endpoint["domain"])
        ssh_port = ssh_endpoint["port"]

        # exec into jobmanager to execute ssh to avoid firewall
        job_manager_pod = utils.kube_get_pods(args.config, "default",
                                              "app=jobmanager")[0]
        job_manager_pod_name = job_manager_pod.metadata.name

        alias = args.email.split("@")[0]

        ssh_cmd = [
            "ssh",
            "-i",
            "/dlwsdata/work/%s/.ssh/id_rsa" % alias,
            "-p",
            str(ssh_port),
            "-o",
            "StrictHostKeyChecking=no",
            "-o",
            "LogLevel=ERROR",
            "%s@%s" % (alias, ssh_host),
            "--",
            "echo a; env | grep CUDA_VISIBLE_DEVICES;",
            "grep CUDA_VISIBLE_DEVICES ~/.ssh/environment; echo b",
        ]
        code, output = utils.kube_pod_exec(args.config, "default",
                                           job_manager_pod_name, "jobmanager",
                                           ssh_cmd)
        assert code == 0, "code is %s, output is %s" % (code, output)

        assert expected in output, "could not find %s in output %s" % (
            expected, output)
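# Hedged usage sketch: the variant above takes job_spec and expected from its
# caller; a driver might look like this (values are assumptions).
def run_cuda_visible_devices_case(args):
    spec = utils.gen_default_job_description("regular", args.email, args.uid,
                                             args.vc, resourcegpu=1)
    # A 1-GPU job is assumed to expose only device 0 inside the ssh session.
    test_ssh_cuda_visible_devices(args, spec, "CUDA_VISIBLE_DEVICES=0")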
def test_distributed_job_mountpoints(args):
    job_spec = utils.gen_default_job_description("distributed", args.email,
                                                 args.uid, args.vc)

    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in({"unapproved", "queued"})
        assert state in ["scheduling", "running"]

        pods = utils.kube_get_pods(args.config, "default",
                                   "jobId=%s" % job.jid)

        mps = utils.load_cluster_nfs_mountpoints(args, job.jid)
        mps.extend(utils.load_system_mountpoints(args))
        mps.extend(utils.load_infiniband_mounts(args))

        for pod in pods:
            for mp in mps:
                assert utils.mountpoint_in_pod(mp, pod), \
                    "mountpoint %s not in distributed job %s" % (mp, job.jid)
def test_sudo_installed(args):
    cmd = "sudo ls"
    image = "pytorch/pytorch:latest"  # no sudo installed in this image

    job_spec = utils.gen_default_job_description(
        "regular",
        args.email,
        args.uid,
        args.vc,
        cmd=cmd,
        image=image,
    )

    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling", "running"})
        log = utils.get_job_log(args.rest, args.email, job.jid)
        logger.debug("job log is %s", log)

        assert state == "finished"
def test_blobfuse(args):
    job_spec = utils.gen_default_job_description("distributed", args.email,
                                                 args.uid, args.vc)

    job_spec["plugins"] = utils.load_azure_blob_config(args.config,
                                                       "/tmp/blob")

    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        ps_label = "jobId=%s,jobRole=ps" % job.jid
        pods = utils.kube_get_pods(args.config, "default", ps_label)
        assert len(pods) == 1

        ps_pod_name = pods[0].metadata.name
        ps_container_name = pods[0].spec.containers[0].name
        msg = "this is dummy from ps"
        ps_cmd = ["bash", "-c", "echo %s > /tmp/blob/${DLWS_JOB_ID}" % (msg)]

        code, output = utils.kube_pod_exec(args.config, "default", ps_pod_name,
                                           ps_container_name, ps_cmd)
        assert code == 0, "code is %d, output is %s" % (code, output)

        worker_label = "jobId=%s,jobRole=worker" % job.jid
        pods = utils.kube_get_pods(args.config, "default", worker_label)
        assert len(pods) == 1

        worker_pod_name = pods[0].metadata.name
        worker_container_name = pods[0].spec.containers[0].name
        worker_cmd = [
            "bash", "-c",
            "cat /tmp/blob/${DLWS_JOB_ID} ; rm /tmp/blob/${DLWS_JOB_ID}"
        ]

        code, output = utils.kube_pod_exec(args.config, "default",
                                           worker_pod_name,
                                           worker_container_name, worker_cmd)
        assert code == 0, "code is %d, output is %s" % (code, output)
        assert msg + "\n" == output, "code is %d, output is %s" % (code,
                                                                   output)
def test_list_all_jobs(args):
    job_spec = utils.gen_default_job_description("regular",
                                                 args.email,
                                                 args.uid,
                                                 args.vc,
                                                 cmd="")

    # The "all" job list should include finished jobs
    with utils.run_job(args.rest, job_spec) as job:
        job_id = job.jid
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling", "running"})
        assert state == "finished"

    resp = utils.get_job_list(args.rest, args.email, args.vc, "all", 10)
    finished_jobs = resp.get("finishedJobs", None)
    assert isinstance(finished_jobs, list)

    finished_job_ids = [job["jobId"] for job in finished_jobs]
    assert job_id in finished_job_ids
def test_regular_job_ssh(args):
    job_spec = utils.gen_default_job_description("regular", args.email,
                                                 args.uid, args.vc)

    with utils.run_job(args.rest, job_spec) as job:
        endpoints = utils.create_endpoint(args.rest, args.email, job.jid,
                                          ["ssh"])
        endpoints_ids = list(endpoints.keys())
        assert len(endpoints_ids) == 1
        endpoint_id = endpoints_ids[0]

        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                 job.jid, endpoint_id)
        logger.debug("endpoints resp is %s", ssh_endpoint)

        ssh_host = "%s.%s" % (ssh_endpoint["nodeName"], ssh_endpoint["domain"])
        ssh_port = ssh_endpoint["port"]

        # exec into jobmanager to execute ssh to avoid firewall
        job_manager_pod = utils.kube_get_pods(args.config, "default",
                                              "app=jobmanager")[0]
        job_manager_pod_name = job_manager_pod.metadata.name

        alias = args.email.split("@")[0]

        cmd = [
            "ssh", "-i",
            "/dlwsdata/work/%s/.ssh/id_rsa" % alias, "-p", ssh_port, "-o",
            "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR",
            "%s@%s" % (alias, ssh_host), "--", "echo", "dummy"
        ]
        code, output = utils.kube_pod_exec(args.config, "default",
                                           job_manager_pod_name, "jobmanager",
                                           cmd)
        assert code == 0, "code is %s, output is %s" % (code, output)
        assert output == "dummy\n", "output is %s" % (output)
def test_image_pull_msg(args):
    expected = "ImagePullBackOff"

    job_spec = utils.gen_default_job_description("distributed",
                                                 args.email,
                                                 args.uid,
                                                 args.vc,
                                                 image="not_exist_image")
    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in({"unapproved", "queued"})
        assert state == "scheduling"

        for _ in range(50):
            details = utils.get_job_detail(args.rest, args.email, job.jid)

            message = utils.walk_json_safe(details, "jobStatusDetail", 0,
                                           "message")
            if message is not None and expected in message:
                break

            time.sleep(0.5)
        assert message is not None and expected in message, \
            "unexpected detail %s" % (details,)
def test_regular_job_mountpoints(args):
    job_spec = utils.gen_default_job_description("regular", args.email,
                                                 args.uid, args.vc)

    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in({"unapproved", "queued"})
        assert state in ["scheduling", "running"]

        pod = utils.kube_get_pods(args.config, "default",
                                  "jobId=%s" % job.jid)[0]

        mps = utils.load_cluster_nfs_mountpoints(args, job.jid)
        mps.extend(utils.load_system_mountpoints(args))

        for mp in mps:
            assert utils.mountpoint_in_pod(mp, pod), \
                "mountpoint %s not in regular job %s" % (mp, job.jid)

        # Regular job should not have IB mounted
        ib_mps = utils.load_infiniband_mounts(args)
        for mp in ib_mps:
            assert not utils.mountpoint_in_pod(mp, pod), \
                "infiniband mountpoint %s in regular job %s" % (mp, job.jid)
def test_distributed_job_running(args, preemptable=False):
    expected = "wantThisInLog"
    cmd = "echo %s ; sleep 120" % expected

    job_spec = utils.gen_default_job_description("distributed",
                                                 args.email,
                                                 args.uid,
                                                 args.vc,
                                                 preemptable=preemptable,
                                                 cmd=cmd)
    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        for _ in range(50):
            log = utils.get_job_log(args.rest, args.email, job.jid)

            if expected in log:
                break

            time.sleep(0.5)
        assert expected in log, "assert {} in {}".format(expected, log)
def test_job_priority(args):
    job_spec = utils.gen_default_job_description("regular", args.email,
                                                 args.uid, args.vc)
    with utils.run_job(args.rest, job_spec) as job:
        # wait until running to avoid state change race
        state = job.block_until_state_not_in({"unapproved", "queued"})
        assert state in ["scheduling", "running"]

        # invalid payload
        resp = utils.set_job_priorities(args.rest, args.email, None)
        assert resp.status_code == 400

        # unauthorized user cannot change priority
        resp = utils.set_job_priorities(args.rest, "unauthorized_user",
                                        {job.jid: 101})
        assert resp.status_code == 403
        priority = utils.get_job_priorities(args.rest)[job.jid]
        assert priority == 100

        # job owner can change priority
        resp = utils.set_job_priorities(args.rest, args.email, {job.jid: 101})
        assert resp.status_code == 200
        priority = utils.get_job_priorities(args.rest)[job.jid]
        assert priority == 101
def test_do_not_expose_private_key(args):
    cmd = "echo a ; printenv DLTS_SSH_PRIVATE_KEY ; echo b"

    job_spec = utils.gen_default_job_description("regular",
                                                 args.email,
                                                 args.uid,
                                                 args.vc,
                                                 cmd=cmd)

    with utils.run_job(args.rest, job_spec) as job:
        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling", "running"})
        assert state == "finished"

        expected = "a\nb"

        for _ in range(50):
            log = utils.get_job_log(args.rest, args.email, job.jid)

            if expected in log:
                break

            time.sleep(0.5)
        assert expected in log, 'assert {} in {}'.format(expected, log)
def test_distributed_with_default_cmd(args):
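    # Note: this script appears to mirror the platform-injected default
    # command; it is kept for reference and is intentionally not passed to
    # the job_spec below.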
    cmd = """
##################################################
# DeepScale 1.0 redirect (Start)
##################################################
set -e

hostname
whoami

# Create DeepScale 1.0 redirect on master node
sudo chmod -R 777 /root && mkdir -p /root/.ssh && rm -f /root/.ssh/config && ln -s ~/.ssh/config /root/.ssh/config && echo "/root/.ssh/config created" && ls -l /root/.ssh
mkdir -p /opt && sudo rm -f /opt/hostfile && sudo ln -s /job/hostfile /opt/hostfile && cat /opt/hostfile

# Create DeepScale 1.0 redirect for all workers
for i in $(seq 0 $(( ${DLWS_NUM_WORKER} - 1 ))); do
    echo "Creating DeepScale 1.0 redirect for worker ${i}"
    ssh worker-${i} "sudo chmod -R 777 /root && mkdir -p /root/.ssh && rm -f /root/.ssh/config && ln -s ~/.ssh/config /root/.ssh/config && echo "/root/.ssh/config created" && ls -l /root/.ssh"
    ssh worker-${i} "mkdir -p /opt && sudo rm -f /opt/hostfile && sudo ln -s /job/hostfile /opt/hostfile && cat /opt/hostfile"
done
##################################################
# DeepScale 1.0 redirect (End)
##################################################

##################################################
# Unlimit memlock (Start)
##################################################
for i in $(seq 0 $(( ${DLWS_NUM_WORKER} - 1 ))); do
    echo "Creating redirect for worker ${i}"
    ssh worker-${i} "sudo bash -c 'echo -e \"*                soft   memlock         unlimited\n*                hard   memlock         unlimited\" | cat >> /etc/security/limits.conf'"
done
##################################################
# Unlimit memlock (End)
##################################################

##################################################
# User command starts here
##################################################
sleep infinity"""
    job_spec = utils.gen_default_job_description("distributed", args.email,
                                                 args.uid, args.vc)
    with utils.run_job(args.rest, job_spec) as job:
        endpoints = utils.create_endpoint(args.rest, args.email, job.jid,
                                          ["ssh"])
        endpoints_ids = list(endpoints.keys())
        assert len(endpoints_ids) == 2

        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        for endpoint_id in endpoints_ids:
            ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                     job.jid, endpoint_id)
            logger.debug("endpoint_id is %s, endpoints resp is %s",
                         endpoint_id, ssh_endpoint)

            ssh_host = "%s.%s" % (ssh_endpoint["nodeName"],
                                  ssh_endpoint["domain"])
            ssh_port = ssh_endpoint["port"]

            # exec into jobmanager to execute ssh to avoid firewall
            job_manager_pod = utils.kube_get_pods(args.config, "default",
                                                  "app=jobmanager")[0]
            job_manager_pod_name = job_manager_pod.metadata.name

            alias = args.email.split("@")[0]

            cmd_prefix = [
                "ssh",
                "-i",
                "/dlwsdata/work/%s/.ssh/id_rsa" % alias,
                "-p",
                str(ssh_port),
                "-o",
                "StrictHostKeyChecking=no",
                "-o",
                "LogLevel=ERROR",
                "%s@%s" % (alias, ssh_host),
                "--",
            ]

            # check they can connect to each other
            for role in ["ps-0", "worker-0"]:
                cmd = copy.deepcopy(cmd_prefix)
                cmd.extend([
                    "ssh", role, "-o", "LogLevel=ERROR", "--", "echo", "dummy"
                ])
                code, output = utils.kube_pod_exec(args.config, "default",
                                                   job_manager_pod_name,
                                                   "jobmanager", cmd)
                logger.debug("code %s, output '%s'", code, output)
                assert code == 0
                assert output == "dummy\n"
def test_fault_tolerance(args):
    # Job is only retried when launcher is controller.
    if utils.get_launcher(args.config) == "python":
        return

    job_spec = utils.gen_default_job_description("regular", args.email,
                                                 args.uid, args.vc)

    with utils.run_job(args.rest, job_spec) as job:
        endpoints = utils.create_endpoint(args.rest, args.email, job.jid,
                                          ["ssh"])
        endpoints_ids = list(endpoints.keys())
        assert len(endpoints_ids) == 1
        endpoint_id = endpoints_ids[0]

        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                 job.jid, endpoint_id)
        ssh_host = "%s.%s" % (ssh_endpoint["nodeName"], ssh_endpoint["domain"])
        ssh_port = ssh_endpoint["port"]

        logger.info("current ssh endpoint is %s:%s", ssh_host, ssh_port)

        pod = utils.kube_get_pods(args.config, "default",
                                  "jobId=%s" % (job.jid))[0]
        utils.kube_delete_pod(args.config, "default", pod.metadata.name)

        ssh_endpoint = utils.wait_endpoint_state(args.rest,
                                                 args.email,
                                                 job.jid,
                                                 endpoint_id,
                                                 state="pending")

        ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                 job.jid, endpoint_id)

        ssh_host = "%s.%s" % (ssh_endpoint["nodeName"], ssh_endpoint["domain"])
        ssh_port = ssh_endpoint["port"]
        logger.info("current ssh endpoint is %s:%s", ssh_host, ssh_port)

        # exec into jobmanager to execute ssh to avoid firewall
        job_manager_pod = utils.kube_get_pods(args.config, "default",
                                              "app=jobmanager")[0]
        job_manager_pod_name = job_manager_pod.metadata.name

        alias = args.email.split("@")[0]

        cmd = [
            "ssh", "-i",
            "/dlwsdata/work/%s/.ssh/id_rsa" % alias, "-p", ssh_port, "-o",
            "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR",
            "%s@%s" % (alias, ssh_host), "--", "echo", "dummy"
        ]
        code, output = utils.kube_pod_exec(args.config, "default",
                                           job_manager_pod_name, "jobmanager",
                                           cmd)
        assert code == 0, "code is %s, output is %s" % (code, output)
        assert output == "dummy\n", "output is %s" % (output)
def test_distributed_job_env(args):
    envs = {
        "DLWS_HOST_NETWORK": "enable",
        "DLTS_HOST_NETWORK": "enable",
        "DLWS_NUM_PS": "1",
        "DLTS_NUM_PS": "1",
        "DLWS_NUM_WORKER": "1",
        "DLTS_NUM_WORKER": "1",
        "DLWS_NUM_GPU_PER_WORKER": "0",
        "DLTS_NUM_GPU_PER_WORKER": "0",
        "DLWS_VC_NAME": str(args.vc),
        "DLTS_VC_NAME": str(args.vc),
        "DLWS_UID": str(args.uid),
        "DLTS_UID": str(args.uid),
        "DLWS_USER_NAME": args.email.split("@")[0],
        "DLTS_USER_NAME": args.email.split("@")[0],
        "DLWS_USER_EMAIL": args.email,
        "DLTS_USER_EMAIL": args.email,
        "DLWS_ROLE_NAME": "master",
        "DLTS_ROLE_NAME": "master",
        "DLWS_JOB_ID": "unknown",
        "DLTS_JOB_ID": "unknown",
    }

    job_spec = utils.gen_default_job_description("distributed",
                                                 args.email,
                                                 args.uid,
                                                 args.vc,
                                                 cmd="sleep infinity")
    with utils.run_job(args.rest, job_spec) as job:
        endpoints = utils.create_endpoint(args.rest, args.email, job.jid,
                                          ["ssh"])
        endpoints_ids = list(endpoints.keys())

        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"
        envs["DLWS_JOB_ID"] = job.jid
        envs["DLTS_JOB_ID"] = job.jid

        for endpoint_id in endpoints_ids:
            ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                     job.jid, endpoint_id)
            logger.debug("endpoints resp is %s", ssh_endpoint)

            ssh_host = "%s.%s" % (ssh_endpoint["nodeName"],
                                  ssh_endpoint["domain"])
            ssh_port = ssh_endpoint["port"]
            ssh_id = ssh_endpoint["id"]

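            # The endpoint id is assumed to end with "...-<role><idx>-<suffix>"
            # (e.g. "...-worker0-ssh"), so the second-to-last dash-separated
            # field carries the role name and index parsed below.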
            role_idx = ssh_id.split("-")[-2]
            match = re.match("([a-z]+)([0-9]+)", role_idx)
            assert match is not None, "%s is not role index name" % (role_idx)

            role, idx = match.groups()

            envs["DLWS_ROLE_NAME"] = role
            envs["DLTS_ROLE_NAME"] = role
            envs["DLWS_ROLE_IDX"] = idx
            envs["DLTS_ROLE_IDX"] = idx

            bash_cmd = ";".join([
                "printf '%s=' ; printenv %s" % (key, key)
                for key, _ in envs.items()
            ])

            # exec into jobmanager to execute ssh to avoid firewall
            job_manager_pod = utils.kube_get_pods(args.config, "default",
                                                  "app=jobmanager")[0]
            job_manager_pod_name = job_manager_pod.metadata.name

            alias = args.email.split("@")[0]

            ssh_cmd = [
                "ssh",
                "-i",
                "/dlwsdata/work/%s/.ssh/id_rsa" % alias,
                "-p",
                str(ssh_port),
                "-o",
                "StrictHostKeyChecking=no",
                "-o",
                "LogLevel=ERROR",
                "%s@%s" % (alias, ssh_host),
                "--",
            ]
            ssh_cmd.append(bash_cmd)

            code, output = utils.kube_pod_exec(args.config, "default",
                                               job_manager_pod_name,
                                               "jobmanager", ssh_cmd)

            logger.debug("cmd %s code is %s, output is %s", " ".join(ssh_cmd),
                         code, output)

            for key, val in envs.items():
                expected_output = "%s=%s" % (key, val)
                assert output.find(
                    expected_output) != -1, "could not find %s in log %s" % (
                        expected_output, output)
def test_ssh_cuda_visible_devices(args):

    job_spec = utils.gen_default_job_description("distributed",
                                                 args.email,
                                                 args.uid,
                                                 args.vc,
                                                 cmd="sleep infinity",
                                                 resourcegpu=4)
    with utils.run_job(args.rest, job_spec) as job:
        endpoints = utils.create_endpoint(args.rest, args.email, job.jid,
                                          ["ssh"])
        endpoints_ids = list(endpoints.keys())

        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        for endpoint_id in endpoints_ids:
            ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                     job.jid, endpoint_id)
            logger.debug("endpoints resp is %s", ssh_endpoint)

            ssh_host = "%s.%s" % (ssh_endpoint["nodeName"],
                                  ssh_endpoint["domain"])
            ssh_port = ssh_endpoint["port"]
            ssh_id = ssh_endpoint["id"]

            role_idx = ssh_id.split("-")[-2]
            match = re.match("([a-z]+)([0-9]+)", role_idx)
            assert match is not None, "%s is not role index name" % role_idx

            role, idx = match.groups()

            # exec into jobmanager to execute ssh to avoid firewall
            job_manager_pod = utils.kube_get_pods(args.config, "default",
                                                  "app=jobmanager")[0]
            job_manager_pod_name = job_manager_pod.metadata.name

            alias = args.email.split("@")[0]

            ssh_cmd = [
                "ssh",
                "-i",
                "/dlwsdata/work/%s/.ssh/id_rsa" % alias,
                "-p",
                str(ssh_port),
                "-o",
                "StrictHostKeyChecking=no",
                "-o",
                "LogLevel=ERROR",
                "%s@%s" % (alias, ssh_host),
                "--",
                "echo a; env | grep CUDA_VISIBLE_DEVICES;",
                "grep CUDA_VISIBLE_DEVICES ~/.ssh/environment; echo b",
            ]

            code, output = utils.kube_pod_exec(args.config, "default",
                                               job_manager_pod_name,
                                               "jobmanager", ssh_cmd)

            logger.debug("cmd %s code is %s, output is %s", " ".join(ssh_cmd),
                         code, output)

            if role == "ps":
                expected = "a\nb"
            else:
                expected = "a\nCUDA_VISIBLE_DEVICES=0,1,2,3\nCUDA_VISIBLE_DEVICES=0,1,2,3\nb"

            assert expected in output, "could not find %s in output %s" % (
                expected, output)
def test_regular_job_custom_ssh_key(args):
    job_spec = utils.gen_default_job_description("regular", args.email,
                                                 args.uid, args.vc)
    with open("data/id_rsa.pub") as f:
        job_spec["ssh_public_keys"] = [f.read()]

    with utils.run_job(args.rest, job_spec) as job:
        endpoints = utils.create_endpoint(args.rest, args.email, job.jid,
                                          ["ssh"])
        endpoints_ids = list(endpoints.keys())
        assert len(endpoints_ids) == 1
        endpoint_id = endpoints_ids[0]

        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                 job.jid, endpoint_id)
        logger.debug("endpoints resp is %s", ssh_endpoint)

        ssh_host = "%s.%s" % (ssh_endpoint["nodeName"], ssh_endpoint["domain"])
        ssh_port = ssh_endpoint["port"]

        # exec into jobmanager to execute ssh to avoid firewall
        job_manager_pod = utils.kube_get_pods(args.config, "default",
                                              "app=jobmanager")[0]
        job_manager_pod_name = job_manager_pod.metadata.name

        alias = args.email.split("@")[0]

        dest = "/tmp/test_regular_job_customer_ssh_key"

        script_cmd = []

        with open("data/id_rsa") as f:
            script_cmd.append("rm %s ; " % dest)

            for line in f.readlines():
                script_cmd.append("echo")
                script_cmd.append(line.strip())
                script_cmd.append(">> %s ;" % dest)

            script_cmd.append("chmod 400 %s ;" % dest)

        cmd = ["sh", "-c", " ".join(script_cmd)]

        code, output = utils.kube_pod_exec(args.config, "default",
                                           job_manager_pod_name, "jobmanager",
                                           cmd)
        assert code == 0, "code is %s, output is %s" % (code, output)

        cmd = [
            "ssh", "-i", dest, "-p", ssh_port, "-o",
            "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR",
            "%s@%s" % (alias, ssh_host), "--", "echo", "dummy"
        ]
        code, output = utils.kube_pod_exec(args.config, "default",
                                           job_manager_pod_name, "jobmanager",
                                           cmd)
        assert code == 0, "code is %s, output is %s" % (code, output)
        assert output == "dummy\n", "output is %s" % (output)
def test_distributed_job_system_envs(args):
    envs = utils.load_system_envs(args)

    job_spec = utils.gen_default_job_description("distributed",
                                                 args.email,
                                                 args.uid,
                                                 args.vc,
                                                 cmd="sleep infinity")
    with utils.run_job(args.rest, job_spec) as job:
        endpoints = utils.create_endpoint(args.rest, args.email, job.jid,
                                          ["ssh"])
        endpoints_ids = list(endpoints.keys())

        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        for endpoint_id in endpoints_ids:
            ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                     job.jid, endpoint_id)
            logger.debug("endpoints resp is %s", ssh_endpoint)

            ssh_host = "%s.%s" % (ssh_endpoint["nodeName"],
                                  ssh_endpoint["domain"])
            ssh_port = ssh_endpoint["port"]
            ssh_id = ssh_endpoint["id"]

            bash_cmd = ";".join([
                "printf '%s=' ; printenv %s" % (key, key)
                for key, _ in envs.items()
            ])

            # exec into jobmanager to execute ssh to avoid firewall
            job_manager_pod = utils.kube_get_pods(args.config, "default",
                                                  "app=jobmanager")[0]
            job_manager_pod_name = job_manager_pod.metadata.name

            alias = args.email.split("@")[0]

            ssh_cmd = [
                "ssh",
                "-i",
                "/dlwsdata/work/%s/.ssh/id_rsa" % alias,
                "-p",
                str(ssh_port),
                "-o",
                "StrictHostKeyChecking=no",
                "-o",
                "LogLevel=ERROR",
                "%s@%s" % (alias, ssh_host),
                "--",
            ]
            ssh_cmd.append(bash_cmd)

            code, output = utils.kube_pod_exec(args.config, "default",
                                               job_manager_pod_name,
                                               "jobmanager", ssh_cmd)

            logger.debug("cmd %s code is %s, output is %s", " ".join(ssh_cmd),
                         code, output)

            for key, val in envs.items():
                expected_output = "%s=%s" % (key, val)
                assert output.find(
                    expected_output) != -1, "could not find %s in log %s" % (
                        expected_output, output)
def test_regular_job_env(args):
    envs = {
        "DLWS_HOST_NETWORK": "",
        "DLTS_HOST_NETWORK": "",
        "DLWS_NUM_PS": "0",
        "DLTS_NUM_PS": "0",
        "DLWS_NUM_WORKER": "1",
        "DLTS_NUM_WORKER": "1",
        "DLWS_NUM_GPU_PER_WORKER": "0",
        "DLTS_NUM_GPU_PER_WORKER": "0",
        "DLWS_VC_NAME": str(args.vc),
        "DLTS_VC_NAME": str(args.vc),
        "DLWS_UID": str(args.uid),
        "DLTS_UID": str(args.uid),
        "DLWS_USER_NAME": args.email.split("@")[0],
        "DLTS_USER_NAME": args.email.split("@")[0],
        "DLWS_USER_EMAIL": args.email,
        "DLTS_USER_EMAIL": args.email,
        "DLWS_ROLE_NAME": "master",
        "DLTS_ROLE_NAME": "master",
        "DLWS_JOB_ID": "unknown",
        "DLTS_JOB_ID": "unknown",
    }

    job_spec = utils.gen_default_job_description("regular", args.email,
                                                 args.uid, args.vc)

    with utils.run_job(args.rest, job_spec) as job:
        envs["DLWS_JOB_ID"] = job.jid
        envs["DLTS_JOB_ID"] = job.jid

        endpoints = utils.create_endpoint(args.rest, args.email, job.jid,
                                          ["ssh"])
        endpoints_ids = list(endpoints.keys())
        assert len(endpoints_ids) == 1
        endpoint_id = endpoints_ids[0]

        state = job.block_until_state_not_in(
            {"unapproved", "queued", "scheduling"})
        assert state == "running"

        ssh_endpoint = utils.wait_endpoint_state(args.rest, args.email,
                                                 job.jid, endpoint_id)
        logger.debug("endpoints resp is %s", ssh_endpoint)

        ssh_host = "%s.%s" % (ssh_endpoint["nodeName"], ssh_endpoint["domain"])
        ssh_port = ssh_endpoint["port"]

        # exec into jobmanager to execute ssh to avoid firewall
        job_manager_pod = utils.kube_get_pods(args.config, "default",
                                              "app=jobmanager")[0]
        job_manager_pod_name = job_manager_pod.metadata.name

        alias = args.email.split("@")[0]

        bash_cmd = ";".join([
            "printf '%s=' ; printenv %s" % (key, key)
            for key, _ in envs.items()
        ])

        ssh_cmd = [
            "ssh",
            "-i",
            "/dlwsdata/work/%s/.ssh/id_rsa" % alias,
            "-p",
            str(ssh_port),
            "-o",
            "StrictHostKeyChecking=no",
            "-o",
            "LogLevel=ERROR",
            "%s@%s" % (alias, ssh_host),
            "--",
        ]
        ssh_cmd.append(bash_cmd)
        code, output = utils.kube_pod_exec(args.config, "default",
                                           job_manager_pod_name, "jobmanager",
                                           ssh_cmd)
        assert code == 0, "code is %s, output is %s" % (code, output)

        for key, val in envs.items():
            expected_output = "%s=%s" % (key, val)
            assert output.find(
                expected_output) != -1, "could not find %s in log %s" % (
                    expected_output, output)
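# Hedged refactor: the printenv round-trip assertion loop appears verbatim in
# test_inference_job_running, test_distributed_job_env,
# test_distributed_job_system_envs and test_regular_job_env above; a single
# shared helper (hypothetical, not part of utils) would do.
def assert_envs_in_output(output, envs):
    for key, val in envs.items():
        expected_output = "%s=%s" % (key, val)
        assert expected_output in output, "could not find %s in output %s" % (
            expected_output, output)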