示例#1
0
 def download_compressed(self, url, compression='tar', uncompress=True, delete_after_uncompress=False, dir=None,
                         api_version=1):
     """
     Download a compressed file from ``url`` and optionally uncompress it.

     :param url: remote location of the compressed file
     :param compression: archive format, 'tar' or 'zip'
     :param uncompress: extract the archive after download when True
     :param delete_after_uncompress: remove the downloaded archive at the end
     :param dir: if given, create this directory and chdir into it first;
                 raises ExistedException when it already exists
                 (name shadows the ``dir`` builtin -- kept for API compatibility)
     :param api_version: API version forwarded to ``self.download``
     :return: the downloaded filename, or False on connection errors
     """
     if dir:
         if os.path.exists(dir):
             raise ExistedException
         else:
             os.mkdir(dir)
             # NOTE: changes the process-wide working directory so that the
             # download and extraction below land inside `dir`.
             os.chdir(dir)
     try:
         logger.info("Downloading the tar file to the current directory ...")
         filename = self.download(url=url, filename='output', api_version=api_version)
         if filename and os.path.isfile(filename) and uncompress:
             # Fixed typo in the log message ("Uncompressring" -> "Uncompressing").
             logger.info("Uncompressing the contents of the file ...")
             if compression == 'tar':
                 # Context managers guarantee the archive handles are closed
                 # even if extraction raises.
                 with tarfile.open(filename) as tar:
                     tar.extractall()
             elif compression == 'zip':
                 # Renamed local (was `zip`) to avoid shadowing the builtin.
                 with zipfile.ZipFile(filename) as archive:
                     archive.extractall()
         # Guard: `filename` may be False/None when the download failed;
         # os.remove would raise in that case.
         if delete_after_uncompress and filename and os.path.isfile(filename):
             logger.info("Cleaning up the compressed file ...")
             os.remove(filename)
         return filename
     except requests.exceptions.ConnectionError as e:
         logger.error("Download ERROR! {}".format(e))
         return False
示例#2
0
def process_data_ids(data):
    """
    Validate and resolve up to 5 ``id[:path]`` dataset references.

    Each entry may be a bare data id/uri or ``id:mount_path``. The data
    entity is fetched from the server to:
    1. confirm that the data id or uri exists and has the right permissions;
    2. if a uri is used, resolve it to the dataset id.

    :param data: iterable of ``id`` or ``id:path`` strings
    :return: ``(True, ["id:path", ...])`` on success, ``(False, None)`` otherwise
    """
    if len(data) > 5:
        logger.error(
            "Cannot attach more than 5 datasets to a task")
        return False, None
    data_ids = []
    mc = DataClient()
    for data_id_and_path in data:
        if ':' in data_id_and_path:
            # Split only on the first ':' so mount paths may themselves
            # contain colons (plain split would raise ValueError here).
            data_id, path = data_id_and_path.split(':', 1)
        else:
            data_id = data_id_and_path
            path = None
        data_obj = mc.get(data_id)
        if not data_obj:
            logger.error("Data not found by id: {}".format(data_id))
            return False, None
        if path is None:
            # Default mount path: "<name>-<version>".
            path = "{}-{}".format(data_obj.name, data_obj.version)
        data_ids.append("{id}:{path}".format(id=data_obj.id, path=path))

    return True, data_ids
示例#3
0
    def create(self, user_name, project_name):
        """
        Register a remote project for ``user_name``/``project_name``.

        Issues a PUT to the project API endpoint. Any failure is logged and
        swallowed, in which case None is returned instead of a response.
        """
        try:
            target = self.project_api_url.format(user_name=user_name,
                                                 project_name=project_name)
            return self.request(method="PUT", url=target)
        except Exception as e:
            logger.error("Create remote project failed, reason: {}".format(str(e)))
            return None
示例#4
0
def delete(id, yes):
    """
    Delete data set.

    :param id: id of the data set to delete (name shadows the ``id`` builtin,
               kept for CLI compatibility)
    :param yes: skip the interactive confirmation prompt when True
    """
    # Reuse one client instance for both the lookup and the delete call.
    client = DataClient()
    data_source = client.get(id)
    if not data_source:
        # Bail out early instead of crashing on `data_source.name` below.
        cl_logger.error("Data not found by id: {}".format(id))
        return

    if not yes:
        click.confirm('Delete Data: {}?'.format(data_source.name), abort=True, default=False)

    if client.delete(id):
        cl_logger.info("Data deleted")
    else:
        cl_logger.error("Failed to delete data")
示例#5
0
def init(id, name):
    """
    Initialize new project at the current dir.

        russell init --name test_name

    or

        russell init --id 151af60026cd462792fa5d77ef79be4d
    """
    # NOTE: the docstring above is user-facing -- it is echoed verbatim in the
    # warning below via init.__doc__, so keep it in sync with the CLI usage.
    if not id and not name:
        logger.warning("Neither id or name offered\n{}".format(init.__doc__))
        return
    # presumably sets up the .russellignore scaffold -- TODO confirm
    RussellIgnoreManager.init()
    try:
        pc = ProjectClient()
    except Exception as e:
        logger.error(str(e))
        return

    access_token = AuthConfigManager.get_access_token()
    project_info = {}
    # Resolve the project either by explicit id, or by name under the
    # authenticated user's account.
    try:
        if id:
            project_info = pc.get_project_info_by_id(id=id)
        elif name:
            project_info = pc.get_project_info_by_name(access_token.username,
                                                       name)
    except Exception as e:
        logger.error(str(e))
        return

    else:
        # Only the project owner may initialize it locally.
        if AuthClient().get_user(
                access_token.token).uid != project_info.get('owner_id'):
            logger.info("You can create a project then run 'russell init'")
            return
        project_id = project_info.get('id')
        name = project_info.get('name', '')
        if project_id:
            # Persist the project binding so later commands can find it.
            experiment_config = dict(name=name, project_id=project_id)
            ExperimentConfigManager.set_config(experiment_config)
            logger.info(
                "Project \"{}\" initialized in current directory".format(name))
        else:
            logger.error(
                "Project \"{}\" initialization failed in current directory".
                format(name))
示例#6
0
def run(resubmit, command, env, jupyter, tensorboard, data, version, message, os, cputype, cpunum, gputype, gpunum,
        memtype, memnum, eager, value, earliest, deadline, duration):
    '''
    Build a job specification from the local project config, submit it as a
    bid (value / time window / duration), and optionally wait for the
    Jupyter / TensorBoard endpoints to come up.

    :param resubmit: resubmit a previously bid job instead of creating a new one
    :param command: command tokens to run remotely (joined with spaces below)
    :param env: deep-learning framework; falls back to the project's default_env
    :param jupyter: enable a Jupyter notebook and wait for its URL
    :param tensorboard: enable TensorBoard and print its URL
    :param data: datasets to attach -- currently unused (handling commented out)
    :param version: code version label (upload path is commented out)
    :param message: free-form note, at most 1024 characters
    :param os: requested operating system (shadows the ``os`` module name here;
               NOTE(review): the spec below hard-codes "ubuntu:16" instead)
    :param cputype: CPU type
    :param cpunum: number of CPUs
    :param gputype: GPU type
    :param gpunum: number of GPUs
    :param memtype: memory type
    :param memnum: amount of memory
    :param eager: unused in this function -- TODO confirm intended behavior
    :param value: bid value for the job request
    :param earliest: earliest start of the time window (tw_start)
    :param deadline: end of the time window (tw_end)
    :param duration: requested run duration
    :return: None
    '''
    """
    """
    # Initialize the experiment client.
    try:
        ec = ExperimentClient()
    except Exception as e:
        logger.error(str(e))
        return
    if resubmit is True:
        # Resubmit path: only the bidding-related parameters matter here.
        # NOTE(review): jobSpec is left empty, so jobSpec["id"] below raises
        # KeyError -- this branch looks unimplemented; confirm before use.
        jobSpec = {}  # TODO: load the previous job spec (last failed bid, or the last successful one from local config) from local config or the server
        jobId = jobSpec["id"]
        # Submit the bid request.
        jobReq = JobReq(duration=duration, tw_end=deadline, tw_start=earliest, job_id=jobId, value=value,
                        resources=jobSpec["resources"])
        resp = ec.submit(jobId, jobReq)
        if resp["accepted"] == False:
            logger.info("This job submit is not accepted, reason: {}".format(resp["message"]))
            return
        # NOTE(review): on acceptance this falls through and creates a brand
        # new job below -- confirm that is intended for resubmit.
    # Reject overly long notes.
    if message and len(message) > 1024:
        logger.error("Message body length over limit")
        return

    # Fetch the auth token.
    access_token = AuthConfigManager.get_access_token()
    # Read the local job (experiment) configuration.
    experiment_config = ExperimentConfigManager.get_config()

    # Join the command tokens into a single command line.
    command_str = ' '.join(command)
    # # Attach mounted datasets (disabled).
    # success, data_ids = process_data_ids(data)
    # if not success:
    #     return

    # Resolve the deep-learning framework.
    if not env:
        # Not specified: use the owning project's default framework for this job.
        env = ProjectClient().get_project_info_by_id(experiment_config["project_id"]).get('default_env')

    # Check that the requested resource combination is legal.
    if not validate_resource_list(env, jupyter, tensorboard, os, cputype, cpunum, gputype, gpunum):
        return

    # Upload code to the cloud, or reference code already stored there.
    # # If a code version was specified
    # if version:
    #     module_resp = ModuleClient().get_by_entity_id_version(experiment_config.project_id, version)
    #     if not module_resp:
    #         logger.error("Remote project does not existed")
    #         return
    #     module_id = module_resp.get('id')
    # else:
    #     # Gen temp dir
    #     try:
    #         # upload_files, total_file_size_fmt, total_file_size = get_files_in_directory('.', 'code')
    #         # save_dir(upload_files, _TEMP_DIR)
    #         file_count, size = get_files_in_current_directory('code')
    #         if size > 100 * 1024 * 1024:
    #             sys.exit("Total size: {}. "
    #                      "Code size too large to sync, please keep it under 100MB."
    #                      "If you have data files in the current directory, please upload them "
    #                      "separately using \"russell data\" command and remove them from here.\n".format(
    #                 sizeof_fmt(size)))
    #         copy_files('.', _TEMP_DIR)
    #     except OSError:
    #         sys.exit("Directory contains too many files to upload. Add unused directories to .russellignore file.")
    #         # logger.info("Creating project run. Total upload size: {}".format(total_file_size_fmt))
    #         # logger.debug("Creating module. Uploading: {} files".format(len(upload_files)))
    #
    #     hash_code = dirhash(_TEMP_DIR)
    #     logger.debug("Checking MD5 ...")
    #     module_resp = ModuleClient().get_by_codehash_entity_id(hash_code, experiment_config.project_id)
    #     if module_resp:  # if code same with older version, use existed, don`t need upload
    #         module_id = module_resp.get('id')
    #         version = module_resp.get('version')
    #         logger.info("Use older version-{}.".format(version))
    #     else:
    #         version = experiment_config.version
    #         # Create module
    #         module = Module(name=experiment_config.name,
    #                         description=message,
    #                         family_id=experiment_config.family_id,
    #                         version=version,
    #                         module_type="code",
    #                         entity_id=experiment_config.project_id
    #                         )
    #         module_resp = mc.create(module)
    #         if not module_resp:
    #             logger.error("Remote project does not existed")
    #             return
    #         version = module_resp.get('version')
    #         experiment_config.set_version(version=version)
    #         ExperimentConfigManager.set_config(experiment_config)
    #
    #         module_id = module_resp.get('id')
    #         project_id = module_resp.get('entity_id')
    #         if not project_id == experiment_config.project_id:
    #             logger.error("Project conflict")
    #
    #         logger.debug("Created module with id : {}".format(module_id))
    #
    #         # Upload code to fs
    #         logger.info("Syncing code ...")
    #         fc = FsClient()
    #         try:
    #             fc.socket_upload(file_type="code",
    #                              filename=_TEMP_DIR,
    #                              access_token=access_token.token,
    #                              file_id=module_id,
    #                              user_name=access_token.username,
    #                              data_name=experiment_config.name)
    #         except Exception as e:
    #             shutil.rmtree(_TEMP_DIR)
    #             logger.error("Upload failed: {}".format(str(e)))
    #             return
    #         else:
    #             ### check socket state, some errors like file-server down, cannot be catched by `except`
    #             state = fc.get_state()
    #             if state == SOCKET_STATE.FAILED:
    #                 logger.error("Upload failed, please try after a while...")
    #                 return
    #         finally:
    #             try:
    #                 shutil.rmtree(fc.temp_dir)
    #             except FileNotFoundError:
    #                 pass
    #
    #         ModuleClient().update_codehash(module_id, hash_code)
    #         logger.info("\nUpload finished")
    #
    #     # rm temp dir
    #     shutil.rmtree(_TEMP_DIR)
    #     logger.debug("Created code with id : {}".format(module_id))

    # Build the job specification.
    jobSpecification = JobSpecification(message=message, code_id="", data_ids=[],
                                        command=command_str,
                                        project_id=experiment_config["project_id"],
                                        framework=env,
                                        enable_jupyter=jupyter,
                                        enable_tensorboard=tensorboard,
                                        os="ubuntu:16",
                                        gpunum=gpunum,
                                        gputype=gputype,
                                        cpunum=cpunum,
                                        cputype=cputype,
                                        memnum=memnum,
                                        memtype=memtype)
    # Submit the specification for the server to persist.
    jobId = ec.create(jobSpecification)
    logger.debug("Created job specification : {}".format(jobId))

    # # Update the local job configuration (disabled).
    # experiment_config.set_experiment_predecessor(experiment_id)
    # ExperimentConfigManager.set_config(experiment_config)

    # Print the job description.
    experiment_name = "{}/{}:{}".format(access_token.username,
                                        experiment_config["project_id"],
                                        version)

    table_output = [["JOB ID", "NAME", "VERSION"],
                    [jobId, experiment_name, version]]
    logger.info(tabulate(table_output, headers="firstrow"))
    logger.info("")

    # Submit the bid request.
    jobReq = JobReq(duration=duration, tw_end=deadline, tw_start=earliest, job_id=jobId, value=value,
                    resources=jobSpecification.resources)
    resp = ec.submit(jobId, jobReq)
    if resp["accepted"] == False:
        logger.info("This job submit is not accepted, reason: {}".format(resp["message"]))
        return

    # After a successful submit, handle jupyter / tensorboard endpoints.
    task_url = {}
    if jupyter is True:
        while True:
            # Wait for the experiment / task instances to become available
            try:
                experiment = ec.get(jobId)
                if experiment.state != "waiting" and experiment.task_instances:
                    break
            except Exception as e:
                logger.debug("Experiment not available yet: {}".format(jobId))

            logger.debug("Experiment not available yet: {}".format(jobId))
            # Poll once per second until the task is scheduled.
            sleep(1)
            continue

        task_url = ec.get_task_url(jobId)
        jupyter_url = task_url["jupyter_url"]
        print("Setting up your instance and waiting for Jupyter notebook to become available ...")
        # presumably polls the URL for up to 900 iterations of 2s -- confirm
        # against wait_for_url's implementation.
        if wait_for_url(jupyter_url, sleep_duration_seconds=2, iterations=900):
            logger.info("\nPath to jupyter notebook: {}".format(jupyter_url))
            webbrowser.open(jupyter_url)
        else:
            logger.info("\nPath to jupyter notebook: {}".format(jupyter_url))
            logger.info(
                "Notebook is still loading or can not be connected now. View logs to track progress")

    if tensorboard is True:
        if not task_url.get("tensorboard_url"):
            # The jupyter branch may already have fetched the URLs; refresh otherwise.
            task_url = ec.get_task_url(jobId)
        tensorboard_url = task_url["tensorboard_url"]
        logger.info("\nPath to tensorboard: {}".format(tensorboard_url))

    logger.info("""
        To view logs enter:
            ch logs {}
                """.format(jobId))