예제 #1
0
def start_endpoints():
    """Reconcile every pending endpoint with its job and Kubernetes state.

    For each endpoint returned by ``DataHandler.GetPendingEndpoints()`` whose
    job has ``jobStatus == "running"``:

    * set ``endpointDescriptionPath`` to ``<endpoint_id>.yaml`` in the same
      directory as the job's ``jobDescriptionPath``;
    * if ``get_k8s_endpoint`` returns a non-empty description, store the
      parsed JSON, mark the endpoint ``"running"`` and, when the pod is
      found, record its ``nodeName``;
    * otherwise call ``start_endpoint`` to create it;
    * stamp ``lastUpdated`` and persist the endpoint via ``UpdateEndpoint``.

    A failure while processing one endpoint is logged and the loop moves on
    to the next. Errors derived from ``Exception`` are logged rather than
    propagated to the caller.
    """
    try:
        data_handler = DataHandler()
    except Exception:
        logger.exception("create data handler failed")
        return

    try:
        pending_endpoints = data_handler.GetPendingEndpoints()

        for endpoint_id, endpoint in pending_endpoints.items():
            try:
                job = data_handler.GetJob(jobId=endpoint["jobId"])[0]
                if job["jobStatus"] != "running":
                    continue

                # job["jobDescriptionPath"] is expected to look like
                # "jobfiles/<yymmdd>/<jobId>/<jobId>.yaml"; the endpoint
                # description file lives in the same directory. The raw
                # string avoids the invalid "\." escape warning and the
                # escaped dot makes ".yaml" match a literal extension.
                endpoint_description_dir = re.search(
                    r"(.*/)[^/.]+\.yaml",
                    job["jobDescriptionPath"]).group(1)
                endpoint["endpointDescriptionPath"] = os.path.join(
                    endpoint_description_dir, endpoint_id + ".yaml")

                logger.info(
                    "\n\n\n\n\n\n----------------Begin to start endpoint %s",
                    endpoint["id"])
                output = get_k8s_endpoint(
                    endpoint["endpointDescriptionPath"])
                if output != "":
                    # Already present in Kubernetes: record its description.
                    endpoint_description = json.loads(output)
                    endpoint["endpointDescription"] = endpoint_description
                    endpoint["status"] = "running"
                    pod = k8sUtils.GetPod("podName=" + endpoint["podName"])
                    if "items" in pod and len(pod["items"]) > 0:
                        endpoint["nodeName"] = pod["items"][0]["spec"][
                            "nodeName"]
                else:
                    start_endpoint(endpoint)

                endpoint["lastUpdated"] = datetime.datetime.now().isoformat()
                data_handler.UpdateEndpoint(endpoint)
            except Exception:
                # One broken endpoint must not block the rest of the batch.
                logger.warning(
                    "Process endpoint failed {}".format(endpoint),
                    exc_info=True)
    except Exception:
        logger.exception("start endpoint failed")
    finally:
        try:
            data_handler.Close()
        except Exception:
            logger.exception("close data handler failed")
예제 #2
0
def start_endpoints():
    """Reconcile pending endpoints with Kubernetes and persist the result.

    For each endpoint returned by ``DataHandler.GetPendingEndpoints()`` whose
    job has ``jobStatus == "running"`` and whose pod passes
    ``is_user_ready``:

    * set ``endpointDescriptionPath`` to ``<endpoint_id>.yaml`` next to the
      job's ``jobDescriptionPath``;
    * if ``get_k8s_endpoint`` returns a non-empty description, store it, mark
      the endpoint ``"running"``, copy the first service port into
      ``endpoint["port"]`` for host-network endpoints and record the pod's
      ``nodeName``;
    * otherwise call ``start_endpoint`` to create it;
    * stamp ``lastUpdated`` and persist the endpoint.

    A failure on one endpoint is printed with its traceback and the loop
    continues with the next one. The data handler is always closed.
    Exceptions derived from ``Exception`` are printed, not propagated.
    """
    data_handler = None
    try:
        data_handler = DataHandler()
        pending_endpoints = data_handler.GetPendingEndpoints()

        for endpoint_id, endpoint in pending_endpoints.items():
            try:
                job = data_handler.GetJob(jobId=endpoint["jobId"])[0]
                if job["jobStatus"] != "running":
                    continue
                if not is_user_ready(endpoint["podName"]):
                    continue

                # job["jobDescriptionPath"] is expected to look like
                # "jobfiles/<yymmdd>/<jobId>/<jobId>.yaml"; the endpoint
                # description file lives in the same directory.
                endpoint_description_dir = re.search(
                    r"(.*/)[^/.]+\.yaml", job["jobDescriptionPath"]).group(1)
                endpoint["endpointDescriptionPath"] = os.path.join(
                    endpoint_description_dir, endpoint_id + ".yaml")

                print("\n\n\n\n\n\n----------------Begin to start endpoint %s" %
                      endpoint["id"])
                output = get_k8s_endpoint(endpoint["endpointDescriptionPath"])
                if output != "":
                    endpoint_description = json.loads(output)
                    endpoint["endpointDescription"] = endpoint_description
                    endpoint["status"] = "running"
                    # A missing "hostNetwork" key is treated as False instead
                    # of aborting with KeyError.
                    if endpoint.get("hostNetwork"):
                        endpoint["port"] = endpoint_description["spec"][
                            "ports"][0]["port"]

                    pod = k8sUtils.GetPod("podName=" + endpoint["podName"])
                    if "items" in pod and len(pod["items"]) > 0:
                        endpoint["nodeName"] = pod["items"][0]["spec"][
                            "nodeName"]
                else:
                    start_endpoint(endpoint)

                endpoint["lastUpdated"] = datetime.datetime.now().isoformat()
                data_handler.UpdateEndpoint(endpoint)
            except Exception:
                # One broken endpoint must not block the rest of the batch.
                print("Process endpoint failed %s" % endpoint_id)
                traceback.print_exc()
    except Exception:
        traceback.print_exc()
    finally:
        # Release the data handler (the original left it open).
        if data_handler is not None:
            try:
                data_handler.Close()
            except Exception:
                traceback.print_exc()
예제 #3
0
def start_endpoints_by_thread(pending_endpoints, data_handler, jobId):
    """Process a batch of pending endpoints and persist their new state.

    Every ``data_handler`` call is made while holding ``sql_lock``. The lock
    and the function name suggest batches run concurrently on threads; this
    is not verified here. For each endpoint whose job is ``"running"``:

    * set ``endpointDescriptionPath`` to ``<endpoint_id>.yaml`` next to the
      job's ``jobDescriptionPath``;
    * if ``get_k8s_endpoint`` returns a description, record it and its first
      ``nodePort`` as ``endpoint["port"]``, then call ``start_endpoint``;
    * mark the endpoint ``"running"`` once ``is_server_ready`` reports true,
      and record the pod's ``nodeName``;
    * if there is no description yet, create the NodePort service first via
      ``create_node_port``;
    * stamp ``lastUpdated`` and persist the endpoint.

    A failure on one endpoint is logged and processing continues.

    Returns:
        ``jobId``, unchanged. Presumably this lets the caller identify which
        batch finished; confirm against the caller.
    """
    for endpoint_id, endpoint in pending_endpoints.items():
        try:
            with sql_lock:
                job = data_handler.GetJob(jobId=endpoint["jobId"])[0]
            if job["jobStatus"] != "running":
                continue

            # job["jobDescriptionPath"] is expected to look like
            # "jobfiles/<yymmdd>/<jobId>/<jobId>.yaml"; the endpoint
            # description file lives in the same directory.
            endpoint_description_dir = re.search(
                r"(.*/)[^/.]+\.yaml", job["jobDescriptionPath"]).group(1)
            endpoint["endpointDescriptionPath"] = os.path.join(
                endpoint_description_dir, endpoint_id + ".yaml")

            logger.info("\n\n\n\n\n\n----------------Begin to start endpoint %s", endpoint["id"])
            output = get_k8s_endpoint(endpoint["endpointDescriptionPath"])
            if output != "":
                endpoint_description = json.loads(output)
                endpoint["endpointDescription"] = endpoint_description
                endpoint["port"] = int(
                    endpoint_description["spec"]["ports"][0]["nodePort"])
                start_endpoint(endpoint)
                # Use the module logger consistently; the original mixed it
                # with the root ``logging`` logger.
                logger.info("\n----------------done for start endpoint %s", endpoint["id"])
                if is_server_ready(endpoint):
                    endpoint["status"] = "running"
                    logger.info("\n----------------endpoint %s is now running", endpoint["id"])
                pod = k8sUtils.GetPod("podName=" + endpoint["podName"])
                if "items" in pod and len(pod["items"]) > 0:
                    endpoint["nodeName"] = pod["items"][0]["spec"]["nodeName"]
            else:
                # No service yet: create the NodePort first. A later pass will
                # then find the description and start the endpoint.
                create_node_port(endpoint)
                logger.info("\n----------------create service done for %s", endpoint["id"])

            endpoint["lastUpdated"] = datetime.datetime.now().isoformat()
            with sql_lock:
                data_handler.UpdateEndpoint(endpoint)
        except Exception:
            logger.warning("Process endpoint failed {}".format(endpoint), exc_info=True)
    return jobId
예제 #4
0
def _format_ps_dist_hosts(pods):
    """Return the comma-separated "<podIP>:<distPort>" list for *pods*."""
    return ",".join(
        "%s:%s" % (pod["status"]["podIP"],
                   int(pod["metadata"]["labels"]["distPort"]))
        for pod in pods)


def _stage_ps_dist_script(jobParams, staging_dir, pod_name, command):
    """Write *command* to <staging_dir>/run_dist_job.sh and hand it to *pod_name*.

    The script is copied to /opt/run_dist_job.sh in the pod, then
    /opt/run_dist_job is touched there.
    """
    import subprocess

    # Equivalent of "mkdir -p" without going through a shell.
    if not os.path.isdir(staging_dir):
        os.makedirs(staging_dir)
    script_path = os.path.join(staging_dir, "run_dist_job.sh")
    with open(script_path, 'w') as f:
        f.write(command + "\n")
    if "userId" in jobParams:
        # Argument list, no shell: userId comes from job parameters and must
        # not be interpreted by a shell.
        subprocess.call(["chown", "-R", str(jobParams["userId"]), script_path])
    k8sUtils.kubectl_exec("cp %s %s:/opt/run_dist_job.sh" % (script_path,
                                                             pod_name))
    k8sUtils.kubectl_exec("exec %s touch /opt/run_dist_job" % pod_name)


def launch_ps_dist_job(jobParams):
    """Launch a parameter-server/worker distributed job on its running pods.

    Pods are selected by ``distRole=ps,run=<jobId>`` and
    ``distRole=worker,run=<jobId>``. The function returns without doing
    anything unless:

    * exactly ``jobParams["numps"]`` ps pods and
      ``jobParams["numpsworker"]`` worker pods exist; and
    * every one of them has ``check_pod_status == "Running"``.

    When those conditions hold, each pod gets a ``run_dist_job.sh`` that runs
    ``jobParams["cmd"]`` with ``--ps_hosts``, ``--worker_hosts``,
    ``--job_name`` and ``--task_index`` flags. The script is copied to
    /opt/run_dist_job.sh in the pod and /opt/run_dist_job is touched there.
    That file is presumably a trigger the pod's entrypoint waits for; verify
    against the pod image. Finally the job's ``jobStatus`` is set to
    ``"running"``.

    Args:
        jobParams: decoded job parameters. Reads "jobId", "numps",
            "numpsworker", "cmd" and, optionally, "userId".
    """
    jobId = jobParams["jobId"]
    workerPodInfo = k8sUtils.GetPod("distRole=worker,run=" + jobId)
    psPodInfo = k8sUtils.GetPod("distRole=ps,run=" + jobId)
    all_pods_exist = (
        "items" in workerPodInfo
        and len(workerPodInfo["items"]) == int(jobParams["numpsworker"])
        and "items" in psPodInfo
        and len(psPodInfo["items"]) == int(jobParams["numps"]))
    if not all_pods_exist:
        return

    ps_pods = psPodInfo["items"]
    worker_pods = workerPodInfo["items"]
    if not all(k8sUtils.check_pod_status(pod) == "Running"
               for pod in worker_pods + ps_pods):
        return

    ps_hosts = _format_ps_dist_hosts(ps_pods)
    worker_hosts = _format_ps_dist_hosts(worker_pods)
    cmd_template = ("%s --ps_hosts=%s --worker_hosts=%s --job_name=%s "
                    "--task_index=%d 2>&1 | tee %s")

    # All ps pods are staged before any worker pod, as before.
    for role, pods in (("ps", ps_pods), ("worker", worker_pods)):
        for task_index, pod in enumerate(pods):
            # The same random /tmp path is the local staging directory and the
            # `tee` log target used inside the pod.
            tmp_path = "/tmp/" + str(uuid.uuid4())
            command = cmd_template % (jobParams["cmd"], ps_hosts,
                                      worker_hosts, role, task_index,
                                      tmp_path)
            _stage_ps_dist_script(jobParams, tmp_path,
                                  pod["metadata"]["name"], command)

    dataHandler = DataHandler()
    try:
        dataHandler.UpdateJobTextField(jobId, "jobStatus", "running")
    finally:
        # The original never released the handler.
        dataHandler.Close()
예제 #5
0
def _delete_dist_job_description(jobDescriptionPath):
    """Run kubectl_delete on *jobDescriptionPath* if it is set and is an existing file."""
    if jobDescriptionPath is not None and os.path.isfile(jobDescriptionPath):
        k8sUtils.kubectl_delete(jobDescriptionPath)


def UpdateDistJobStatus(job):
    """Poll a distributed (ps/worker) job and sync its status to the database.

    Steps:

    * Record ``k8sUtils.GetJobStatus`` detail as ``jobStatusDetail``.
    * If all expected ps/worker pods exist:
      * a ``"scheduling"`` job is launched via ``launch_ps_dist_job``;
      * a ``"running"`` job is checked via ``GetDistJobStatus``:
        * "Succeeded" -> status ``finished``, job objects deleted;
        * "Running" -> log extracted, service endpoints refreshed if the job
          is interactive;
        * "Failed" -> status ``failed``, errorMsg set, job objects deleted;
        * "Unknown" for more than 300 s -> retry count bumped. After 5
          retries the job is marked ``error``; otherwise it is re-submitted.

    The data handler opened here is always closed.

    Args:
        job: a job row. Reads "jobId", "jobStatus", "jobParams" (base64 JSON)
            and optionally "jobDescriptionPath".
    """
    dataHandler = DataHandler()
    try:
        jobParams = json.loads(base64.b64decode(job["jobParams"]))

        if "userId" not in jobParams:
            jobParams["userId"] = "0"

        jobPath, workPath, dataPath = GetStoragePath(jobParams["jobPath"],
                                                     jobParams["workPath"],
                                                     jobParams["dataPath"])
        localJobPath = os.path.join(config["storage-mount-path"], jobPath)
        logPath = os.path.join(localJobPath, "logs/joblog.txt")

        result, detail = k8sUtils.GetJobStatus(job["jobId"])
        dataHandler.UpdateJobTextField(job["jobId"], "jobStatusDetail",
                                       base64.b64encode(detail))

        logging.info("job %s status: %s,%s" %
                     (job["jobId"], result, json.dumps(detail)))

        jobDescriptionPath = os.path.join(
            config["storage-mount-path"],
            job["jobDescriptionPath"]) if "jobDescriptionPath" in job else None

        jobId = jobParams["jobId"]
        workerPodInfo = k8sUtils.GetPod("distRole=worker,run=" + jobId)
        psPodInfo = k8sUtils.GetPod("distRole=ps,run=" + jobId)
        all_pods_exist = (
            "items" in workerPodInfo
            and len(workerPodInfo["items"]) == int(jobParams["numpsworker"])
            and "items" in psPodInfo
            and len(psPodInfo["items"]) == int(jobParams["numps"]))
        if not all_pods_exist:
            return

        if job["jobStatus"] == "scheduling":
            launch_ps_dist_job(jobParams)
        if job["jobStatus"] != "running":
            return

        result, detail = GetDistJobStatus(job["jobId"])
        dataHandler.UpdateJobTextField(job["jobId"], "jobStatusDetail",
                                       base64.b64encode(detail))

        printlog("job %s status: %s" % (job["jobId"], result))

        status = result.strip()
        if status == "Succeeded":
            joblog_manager.extract_job_log(job["jobId"], logPath,
                                           jobParams["userId"])
            dataHandler.UpdateJobTextField(job["jobId"], "jobStatus",
                                           "finished")
            _delete_dist_job_description(jobDescriptionPath)

        elif status == "Running":
            joblog_manager.extract_job_log(job["jobId"], logPath,
                                           jobParams["userId"])
            # jobStatus is already "running" on this path, so no status
            # update is needed; the old re-check was dead code.
            if "interactivePort" in jobParams:
                serviceAddress = k8sUtils.GetServiceAddress(job["jobId"])
                serviceAddress = base64.b64encode(json.dumps(serviceAddress))
                dataHandler.UpdateJobTextField(job["jobId"], "endpoints",
                                               serviceAddress)

        elif status == "Failed":
            printlog("Job %s fails, cleaning..." % job["jobId"])
            joblog_manager.extract_job_log(job["jobId"], logPath,
                                           jobParams["userId"])
            dataHandler.UpdateJobTextField(job["jobId"], "jobStatus", "failed")
            dataHandler.UpdateJobTextField(job["jobId"], "errorMsg", detail)
            _delete_dist_job_description(jobDescriptionPath)

        elif status == "Unknown":
            if job["jobId"] not in UnusualJobs:
                UnusualJobs[job["jobId"]] = datetime.datetime.now()
            # total_seconds(): timedelta.seconds is only the seconds
            # component (0..86399) and wraps after one day.
            elif (datetime.datetime.now() -
                  UnusualJobs[job["jobId"]]).total_seconds() > 300:
                del UnusualJobs[job["jobId"]]
                retries = dataHandler.AddandGetJobRetries(job["jobId"])
                if retries >= 5:
                    printlog("Job %s fails for more than 5 times, abort" %
                             job["jobId"])
                    dataHandler.UpdateJobTextField(job["jobId"], "jobStatus",
                                                   "error")
                    dataHandler.UpdateJobTextField(
                        job["jobId"], "errorMsg", "cannot launch the job.")
                    _delete_dist_job_description(jobDescriptionPath)
                else:
                    printlog(
                        "Job %s fails in Kubernetes, delete and re-submit the job. Retries %d"
                        % (job["jobId"], retries))
                    SubmitJob(job)

        # Leaving the "Unknown" state resets the unusual-job timer.
        if status != "Unknown" and job["jobId"] in UnusualJobs:
            del UnusualJobs[job["jobId"]]
    finally:
        dataHandler.Close()
예제 #6
0
def _render_dist_ssh_config(pods, user_name, host_network):
    """Return ssh_config text with one "Host <distRole>-<distRoleIdx>" entry per pod."""
    ssh_config = """
Host %s
  HostName %s
  Port %s
  User %s
  StrictHostKeyChecking no
  UserKnownHostsFile /dev/null
                """
    entries = []
    for pod in pods:
        labels = pod["metadata"]["labels"]
        host_alias = labels["distRole"] + "-" + str(labels["distRoleIdx"])
        # TODO hostNetwork
        # With hostNetwork the entry targets the pod's distPort (the port
        # handed to start_ssh_server); otherwise it targets port 22.
        port = str(labels["distPort"]) if host_network else 22
        entries.append(
            ssh_config % (host_alias, pod["status"]["podIP"], port, user_name)
            + "\n")
    return "".join(entries)


def _render_dist_hostfile(pods):
    """Return hostfile lines "worker-<distRoleIdx>  slots=<gpu request>" for every non-ps pod."""
    lines = []
    for pod in pods:
        labels = pod["metadata"]["labels"]
        if labels["distRole"] == "ps":
            continue
        worker_gpu_num = pod["spec"]["containers"][0]["resources"][
            "requests"]["nvidia.com/gpu"]
        lines.append("%s  slots=%s\n" % ("worker-" + str(labels["distRoleIdx"]),
                                         worker_gpu_num))
    return "".join(lines)


def launch_ps_dist_job(jobParams):
    """Wire up SSH and a hostfile across a distributed job's pods, then start it.

    Returns without doing anything unless:

    * exactly ``numps + numpsworker`` pods labelled ``run=<jobId>`` exist; and
    * all of them have ``check_pod_status == "Running"``.

    When those conditions hold:

    * an SSH server is started in every pod;
    * ``/home/<user>/.ssh/config`` in every pod is overwritten with one
      ``<role>-<idx>`` host entry per pod, and its permissions are fixed;
    * a hostfile is copied to /job/hostfile in every pod;
    * /opt/run_dist_job is touched in every pod;
    * the job's ``jobStatus`` is set to ``"running"``.

    Args:
        jobParams: decoded job parameters. Reads "jobId", "numps",
            "numpsworker", "userName" and optionally "hostNetwork".
    """
    job_id = jobParams["jobId"]
    pods = k8sUtils.GetPod("run=" + job_id)

    # if any pod is not up, return
    if "items" not in pods or len(pods["items"]) != (
            int(jobParams["numpsworker"]) + int(jobParams["numps"])):
        return
    # if any pod is not ready, return
    if any(k8sUtils.check_pod_status(pod) != "Running"
           for pod in pods["items"]):
        return

    pod_items = pods["items"]
    user_name = getAlias(jobParams["userName"])
    host_network = bool(jobParams.get("hostNetwork"))

    # setup ssh server
    for pod in pod_items:
        # NOTE(review): start_ssh_server's return value is ignored, so a
        # failure here does not stop the launch; confirm whether it should.
        start_ssh_server(pod["metadata"]["name"], user_name, host_network,
                         pod["metadata"]["labels"]["distPort"])

    sshconfigstr = _render_dist_ssh_config(pod_items, user_name, host_network)

    # config ssh client
    for pod in pod_items:
        pod_name = pod["metadata"]["name"]
        bash_script = "cat > /home/" + user_name + "/.ssh/config <<EOF " + sshconfigstr + "\nEOF"
        print("override ssh client config: %s" % bash_script)
        k8sUtils.kubectl_exec("exec %s -- bash -c '%s'" %
                              (pod_name, bash_script))

        # Fix ~/.ssh permissions and ownership. The whole command list is
        # inside the `bash -c` quotes so every step runs in the pod. Commands
        # after an unquoted ";" would instead run in whatever shell
        # kubectl_exec uses to invoke kubectl. This assumes kubectl_exec goes
        # through a shell, as the quoted bash -c argument above already does.
        k8sUtils.kubectl_exec(
            "exec %s -- bash -c 'chmod 600 -R /home/%s/.ssh; "
            "chmod 700 /home/%s/.ssh; chown -R %s /home/%s/.ssh/config'"
            % (pod_name, user_name, user_name, user_name, user_name))

    # generate the hostfile and copy it into every pod
    hostfilecontent = _render_dist_hostfile(pod_items)
    tmp_hostfile = "/tmp/" + job_id + ".hostfile"
    with open(tmp_hostfile, 'w') as f:
        f.write(hostfilecontent + "\n")
    for pod in pod_items:
        k8sUtils.kubectl_exec("cp %s %s:/job/hostfile" %
                              (tmp_hostfile, pod["metadata"]["name"]))

    for pod in pod_items:
        k8sUtils.kubectl_exec("exec %s touch /opt/run_dist_job" %
                              pod["metadata"]["name"])

    # update job status; always release the handler
    dataHandler = DataHandler()
    try:
        dataHandler.UpdateJobTextField(job_id, "jobStatus", "running")
    finally:
        dataHandler.Close()