Exemplo n.º 1
0
def _dist_host_list(pods):
    """Return the comma-separated "ip:port" list for *pods*, in list order."""
    return ",".join(
        "%s:%s" % (pod["status"]["podIP"],
                   int(pod["metadata"]["labels"]["distPort"]))
        for pod in pods)


def _push_dist_job_script(job_cmd, ps_hosts, worker_hosts, job_name,
                          task_index, pod_name, user_id=None):
    """Write the launch script for one task and copy it into *pod_name*.

    The script is written locally to ``/tmp/<uuid>/run_dist_job.sh``. If
    *user_id* is given, it is chowned to that user. It is then copied to
    ``/opt/run_dist_job.sh`` in the pod, and ``/opt/run_dist_job`` is touched
    in the pod.
    """
    import os
    import subprocess
    import uuid

    # The same /tmp/<uuid> path is used twice: inside the pod it is the file
    # the command's output is tee'd to, and on this host it is the directory
    # that holds the generated script.
    log_path = "/tmp/" + str(uuid.uuid4())
    cmd_line = (
        "%s --ps_hosts=%s --worker_hosts=%s --job_name=%s --task_index=%d 2>&1 | tee %s"
        % (job_cmd, ps_hosts, worker_hosts, job_name, task_index, log_path))

    os.makedirs(log_path)
    script_path = os.path.join(log_path, "run_dist_job.sh")
    with open(script_path, "w") as f:
        f.write(cmd_line + "\n")

    if user_id is not None:
        # Use an argument list rather than a shell string: userId comes from
        # job parameters and must not be interpreted by a shell.
        subprocess.call(["chown", "-R", str(user_id), script_path])

    k8sUtils.kubectl_exec("cp %s %s:/opt/run_dist_job.sh" %
                          (script_path, pod_name))
    k8sUtils.kubectl_exec("exec %s touch /opt/run_dist_job" % pod_name)


def launch_ps_dist_job(jobParams):
    """Push per-task launch scripts to a ps/worker job's pods once all are up.

    Returns None without doing anything unless both of these hold:
    - exactly ``numps`` ps pods and ``numpsworker`` worker pods exist;
    - every one of those pods reports "Running".

    Otherwise each pod gets a script that runs ``jobParams["cmd"]`` with
    ``--ps_hosts``, ``--worker_hosts``, ``--job_name`` and ``--task_index``
    flags. The task index is the pod's position in the list returned by
    ``k8sUtils.GetPod``. The job is then marked "running".

    Args:
        jobParams: dict with at least "jobId", "numps", "numpsworker" and
            "cmd". If "userId" is present, the local script files are
            chowned to it.
    """
    job_id = jobParams["jobId"]
    worker_info = k8sUtils.GetPod("distRole=worker,run=" + job_id)
    ps_info = k8sUtils.GetPod("distRole=ps,run=" + job_id)

    # Wait until every expected ps and worker pod has been created.
    if "items" not in worker_info or len(worker_info["items"]) != int(
            jobParams["numpsworker"]):
        return
    if "items" not in ps_info or len(ps_info["items"]) != int(
            jobParams["numps"]):
        return
    worker_pods = worker_info["items"]
    ps_pods = ps_info["items"]

    # ...and until all of them are running.
    if not all(k8sUtils.check_pod_status(pod) == "Running"
               for pod in worker_pods + ps_pods):
        return

    ps_hosts = _dist_host_list(ps_pods)
    worker_hosts = _dist_host_list(worker_pods)
    user_id = jobParams["userId"] if "userId" in jobParams else None

    for task_index, pod in enumerate(ps_pods):
        _push_dist_job_script(jobParams["cmd"], ps_hosts, worker_hosts, "ps",
                              task_index, pod["metadata"]["name"], user_id)
    for task_index, pod in enumerate(worker_pods):
        _push_dist_job_script(jobParams["cmd"], ps_hosts, worker_hosts,
                              "worker", task_index, pod["metadata"]["name"],
                              user_id)

    data_handler = DataHandler()
    try:
        data_handler.UpdateJobTextField(job_id, "jobStatus", "running")
    finally:
        # Release the handler even if the update fails.
        data_handler.Close()
Exemplo n.º 2
0
def _build_ssh_config(pods, user_name, host_network):
    """Return ssh client config text with one "<role>-<idx>" Host per pod.

    The port is the pod's distPort label when *host_network* is true,
    otherwise 22.
    """
    template = """
Host %s
  HostName %s
  Port %s
  User %s
  StrictHostKeyChecking no
  UserKnownHostsFile /dev/null
                """
    entries = []
    for pod in pods:
        labels = pod["metadata"]["labels"]
        alias = labels["distRole"] + "-" + str(labels["distRoleIdx"])
        # Presumably with hostNetwork the in-pod sshd listens on distPort
        # (start_ssh_server is given it) instead of 22 -- confirm.
        port = str(labels["distPort"]) if host_network else 22
        entries.append(
            template % (alias, pod["status"]["podIP"], port, user_name) + "\n")
    return "".join(entries)


def _build_hostfile(pods):
    """Return "worker-<idx>  slots=<gpus>" lines for every non-ps pod."""
    lines = []
    for pod in pods:
        labels = pod["metadata"]["labels"]
        if labels["distRole"] == "ps":
            continue
        # Slots are the first container's "nvidia.com/gpu" request. This
        # raises KeyError if the request is absent.
        gpu_num = pod["spec"]["containers"][0]["resources"]["requests"][
            "nvidia.com/gpu"]
        lines.append("%s  slots=%s\n" %
                     ("worker-" + str(labels["distRoleIdx"]), gpu_num))
    return "".join(lines)


def launch_ps_dist_job(jobParams):
    """Wire up SSH between a distributed job's pods and trigger the job.

    Returns None without doing anything unless all of these hold:
    - ``pods["items"]`` is present;
    - it holds exactly ``numpsworker + numps`` pods;
    - every pod reports "Running".

    Otherwise it performs these steps in order:
    1. Start an SSH server in each pod.
    2. Write a shared ``~/.ssh/config`` into every pod.
    3. Copy a hostfile listing the workers to ``/job/hostfile`` in every pod.
    4. Touch ``/opt/run_dist_job`` in every pod.
    5. Mark the job "running".

    Args:
        jobParams: dict with at least "jobId", "numps", "numpsworker" and
            "userName". An optional truthy "hostNetwork" selects pod distPort
            instead of 22 for SSH.
    """
    job_id = jobParams["jobId"]
    pods = k8sUtils.GetPod("run=" + job_id)

    # If any pod is not created yet, try again later.
    if "items" not in pods or len(pods["items"]) != (
            int(jobParams["numpsworker"]) + int(jobParams["numps"])):
        return
    pod_items = pods["items"]
    # If any pod is not running yet, try again later.
    if any(k8sUtils.check_pod_status(pod) != "Running" for pod in pod_items):
        return

    user_name = getAlias(jobParams["userName"])
    host_network = bool(jobParams.get("hostNetwork"))
    home = "/home/" + user_name

    # Start an SSH server in every pod.
    # NOTE(review): start_ssh_server's return value is ignored and nothing
    # aborts on failure; confirm how it signals errors before relying on it.
    for pod in pod_items:
        start_ssh_server(pod["metadata"]["name"], user_name, host_network,
                         pod["metadata"]["labels"]["distPort"])

    # Install the same ssh client config in every pod.
    ssh_config = _build_ssh_config(pod_items, user_name, host_network)
    for pod in pod_items:
        pod_name = pod["metadata"]["name"]
        bash_script = "cat > " + home + "/.ssh/config <<EOF " + ssh_config + "\nEOF"
        print("override ssh client config: %s" % bash_script)
        k8sUtils.kubectl_exec("exec %s -- bash -c \'%s\'" %
                              (pod_name, bash_script))
        # Fix ~/.ssh permissions and ownership in the pod. All commands go
        # into one in-pod `bash -c`; previously everything after the first
        # `;` sat outside `kubectl exec`, so a shell-based kubectl_exec ran
        # it on this host instead of in the pod.
        k8sUtils.kubectl_exec(
            "exec %s -- bash -c \'chmod 600 -R %s/.ssh; chmod 700 %s/.ssh; chown -R %s %s/.ssh/config\'"
            % (pod_name, home, home, user_name, home))

    # Write the hostfile locally, then copy it into every pod.
    tmp_hostfile = "/tmp/" + job_id + ".hostfile"
    with open(tmp_hostfile, "w") as f:
        f.write(_build_hostfile(pod_items) + "\n")
    for pod in pod_items:
        k8sUtils.kubectl_exec("cp %s %s:/job/hostfile" %
                              (tmp_hostfile, pod["metadata"]["name"]))

    # Touch the trigger file in every pod. Presumably the pod entrypoint
    # waits for it before running the user command -- confirm.
    for pod in pod_items:
        k8sUtils.kubectl_exec("exec %s touch /opt/run_dist_job" %
                              pod["metadata"]["name"])

    data_handler = DataHandler()
    try:
        data_handler.UpdateJobTextField(job_id, "jobStatus", "running")
    finally:
        data_handler.Close()