예제 #1
0
def get_predict_dsl():
    """Return the inference (predict) DSL of a guest or host model.

    Model info is queried with ``query_filters=['inference_dsl']`` using the
    JSON request body as keyword arguments. The first record whose
    ``f_role`` is ``guest`` or ``host`` is used. When the request carries a
    ``filename`` the DSL is written under ``TEMP_DIRECTORY`` and sent back as
    an attachment; otherwise it is returned in the JSON result.

    Returns a 210 error response when no model record is found, or when none
    of the records belongs to a guest or host party.
    """
    request_data = request.json
    request_data['query_filters'] = ['inference_dsl']
    retcode, retmsg, data = model_utils.query_model_info_from_file(
        **request_data)
    if not data:
        return error_response(
            210,
            "No model found, please check if arguments are specified correctly.")

    model_info = next(
        (d for d in data if d.get("f_role") in {"guest", "host"}), None)
    if model_info is None:
        return error_response(
            210,
            "can not found guest or host model, please get predict dsl on guest or host."
        )
    inference_dsl = model_info['f_inference_dsl']

    filename = request_data.get("filename")
    if not filename:
        return get_json_result(data=inference_dsl)

    # The filename is client-controlled. Keep only its last path component
    # so the file cannot be written outside TEMP_DIRECTORY
    # (e.g. "../../somewhere").
    safe_filename = os.path.basename(filename)
    os.makedirs(TEMP_DIRECTORY, exist_ok=True)
    temp_filepath = os.path.join(TEMP_DIRECTORY, safe_filename)
    with open(temp_filepath, "w") as fout:
        fout.write(json_dumps(inference_dsl, indent=4))
    return send_file(open(temp_filepath, "rb"),
                     as_attachment=True,
                     attachment_filename=safe_filename)
예제 #2
0
def component_output_data_download():
    """Send the output tables of one component as a ``.tar.gz`` attachment.

    The task is identified by ``job_id``, ``component_name``, ``role`` and
    ``party_id`` in the JSON body. Optional ``limit`` (default ``-1``) and
    ``head`` (default ``True``) are forwarded to ``TableStorage.send_table``.

    Raises ``ValueError`` when no matching task exists. Returns a 210 error
    response when the output table metadata cannot be loaded, is empty, or
    ``limit`` is 0.
    """
    params = request.json
    job_id = params['job_id']
    component_name = params['component_name']
    role = params['role']
    party_id = params['party_id']

    latest_tasks = JobSaver.query_task(only_latest=True,
                                       job_id=job_id,
                                       component_name=component_name,
                                       role=role,
                                       party_id=party_id)
    if not latest_tasks:
        raise ValueError(
            f'no found task, please check if the parameters are correct:{params}'
        )
    # NOTE(review): presumably loads the task provider's dependencies so its
    # output tables can be read; verify against import_component_output_depend.
    import_component_output_depend(latest_tasks[0].f_provider_info)

    try:
        tables_meta = get_component_output_tables_meta(task_data=params)
    except Exception as exc:
        stat_logger.exception(exc)
        return error_response(210, str(exc))

    row_limit = params.get('limit', -1)
    if not tables_meta:
        return error_response(response_code=210, retmsg='no data')
    if row_limit == 0:
        return error_response(response_code=210, retmsg='limit is 0')

    archive_name = (f'job_{job_id}_{component_name}_{role}_{party_id}'
                    '_output_data.tar.gz')
    return TableStorage.send_table(tables_meta,
                                   archive_name,
                                   limit=row_limit,
                                   need_head=params.get("head", True))
예제 #3
0
def validate_component_param():
    """Validate a single component's parameters with the matching DSL parser.

    The JSON body must be an object holding ``component_name``,
    ``component_module_name``, ``role`` and a ``dsl_version`` of 1 or 2.
    Version 1 also requires ``role_parameters`` and ``algorithm_parameters``
    (checked with ``DSLParser``). Version 2 also requires
    ``component_parameters`` (checked with ``DSLParserV2``).

    Returns an empty JSON result on success. Returns a 400 error response
    when the body is malformed, ``dsl_version`` is unsupported or not
    numeric, required keys are missing, or the parser rejects the
    parameters.
    """
    body = request.json
    if not body or not isinstance(body, dict):
        return error_response(400, 'bad request')

    required_keys = [
        'component_name',
        'component_module_name',
    ]
    config_keys = ['role']

    # int() raises on values such as "abc", None or a list. That is a client
    # error, so answer 400 instead of letting it surface as a server error.
    try:
        dsl_version = int(body.get('dsl_version', 0))
    except (TypeError, ValueError):
        return error_response(400, 'unsupported dsl_version')

    if dsl_version == 1:
        config_keys += ['role_parameters', 'algorithm_parameters']
        parser_class = DSLParser
    elif dsl_version == 2:
        config_keys += ['component_parameters']
        parser_class = DSLParserV2
    else:
        return error_response(400, 'unsupported dsl_version')

    try:
        check_config(body, required_keys + config_keys)
    except Exception as e:
        return error_response(400, str(e))

    try:
        parser_class.validate_component_param(
            get_federatedml_setting_conf_directory(),
            {key: body[key] for key in config_keys},
            *[body[key] for key in required_keys])
    except Exception as e:
        return error_response(400, str(e))

    return get_json_result()
예제 #4
0
파일: tracking_app.py 프로젝트: tarada/FATE
def get_component_summary():
    """Return the summary recorded for one component of a job.

    Required request keys: ``job_id``, ``component_name``, ``role`` and
    ``party_id``. ``task_id`` and ``task_version`` are optional. When
    ``filename`` is given, the summary is written under ``TEMP_DIRECTORY``
    and sent as an attachment; otherwise it is returned in the JSON result.

    Any exception is logged and reported as a 210 error response, as is a
    missing summary.
    """
    request_data = request.json
    try:
        required_params = ["job_id", "component_name", "role", "party_id"]
        detect_utils.check_config(request_data, required_params)
        tracker = Tracker(job_id=request_data["job_id"],
                          component_name=request_data["component_name"],
                          role=request_data["role"],
                          party_id=request_data["party_id"],
                          task_id=request_data.get("task_id", None),
                          task_version=request_data.get("task_version", None))
        summary = tracker.read_summary_from_db()
        if not summary:
            return error_response(
                210,
                "No component summary found, please check if arguments are specified correctly."
            )

        filename = request_data.get("filename")
        if not filename:
            return get_json_result(data=summary)

        # Client-controlled name: keep only the last path component so the
        # file stays inside TEMP_DIRECTORY.
        safe_filename = os.path.basename(filename)
        # open() below fails if the directory does not exist yet.
        os.makedirs(TEMP_DIRECTORY, exist_ok=True)
        temp_filepath = os.path.join(TEMP_DIRECTORY, safe_filename)
        with open(temp_filepath, "w") as fout:
            fout.write(json.dumps(summary, indent=4))
        return send_file(open(temp_filepath, "rb"),
                         as_attachment=True,
                         attachment_filename=safe_filename)
    except Exception as e:
        stat_logger.exception(e)
        return error_response(210, str(e))
예제 #5
0
def get_config():
    """Return a job's configuration, preferring the one used by a task.

    ``job_id``, ``role`` and ``party_id`` are required; a missing one gives
    a 400 response. If ``component_name``, ``task_id`` and ``task_version``
    are all present, the task-level configuration is tried first. A failure
    there is deliberately ignored and the job-level configuration is used
    instead. Returns 404 when neither lookup yields a configuration.
    """
    body = request.json
    kwargs = {}

    for key in ('job_id', 'role', 'party_id'):
        value = body.get(key)
        if value is None:
            return error_response(400, f"'{key}' is required.")
        kwargs[key] = str(value)

    job_configuration = None
    task_keys = ('component_name', 'task_id', 'task_version')
    # The task-level lookup needs every task key.
    if all(body.get(key) is not None for key in task_keys):
        kwargs.update((key, str(body[key])) for key in task_keys)
        try:
            job_configuration = job_utils.get_task_using_job_conf(**kwargs)
        except Exception:
            # Best effort only: fall back to the job-level configuration.
            pass

    if job_configuration is None:
        job_configuration = job_utils.get_job_configuration(
            kwargs['job_id'], kwargs['role'], kwargs['party_id'])
    if job_configuration is None:
        return error_response(404, 'Job not found.')
    return get_json_result(data=job_configuration.to_dict())
예제 #6
0
def transfer_model():
    """Return the published model identified by ``namespace`` and ``name``.

    ``namespace`` is passed as the party model id and ``name`` as the model
    version. Both are required (400 otherwise). Returns 404 when
    ``publish_model.download_model`` yields ``None``.
    """
    body = request.json
    party_model_id, model_version = body.get('namespace'), body.get('name')
    if not (party_model_id and model_version):
        return error_response(400, 'namespace and name are required')

    model_data = publish_model.download_model(party_model_id, model_version)
    if model_data is not None:
        return get_json_result(data=model_data)
    return error_response(404, 'model not found')
예제 #7
0
def component_output_data_download():
    """Export a component's output data table as a ``.tar.gz`` attachment.

    Rows are written as comma-separated lines to a scratch directory under
    ``./tmp``. ``limit`` caps the row count; the default ``-1`` never matches
    the counter, so there is no cap. A ``data_meta.json`` file holding the
    header is written next to the rows. With ``head`` (default ``True``) the
    header is also prepended to the CSV. Both files are packed into an
    in-memory tarball, and the scratch directory is always removed.

    Returns a 500 error response when the table is missing, when ``limit``
    is 0, or when the table yields no rows.
    """
    request_data = request.json
    output_data_table = get_component_output_data_table(task_data=request_data)
    limit = request_data.get('limit', -1)
    if not output_data_table:
        return error_response(response_code=500, retmsg='no data')
    if limit == 0:
        return error_response(response_code=500, retmsg='limit is 0')

    output_tmp_dir = os.path.join(os.getcwd(), 'tmp/{}'.format(fate_uuid()))
    output_file_path = '{}/output_%s'.format(output_tmp_dir)
    output_data_file_path = output_file_path % 'data.csv'
    output_data_meta_file_path = output_file_path % 'data_meta.json'
    os.makedirs(output_tmp_dir, exist_ok=True)
    try:
        output_data_count = 0
        have_data_label = False
        with open(output_data_file_path, 'w') as fw:
            for k, v in output_data_table.collect():
                data_line, have_data_label = get_component_output_data_line(
                    src_key=k, src_value=v)
                fw.write('{}\n'.format(','.join(map(str, data_line))))
                output_data_count += 1
                if output_data_count == limit:
                    break

        if not output_data_count:
            # Answer explicitly: falling through would return None, which
            # Flask rejects as an invalid view response.
            return error_response(response_code=500, retmsg='no data')

        header = get_component_output_data_meta(
            output_data_table=output_data_table,
            have_data_label=have_data_label)
        with open(output_data_meta_file_path, 'w') as fw:
            json.dump({'header': header}, fw, indent=4)
        if request_data.get('head', True):
            # The header is only known after scanning the rows, so prepend
            # it by rewriting the file.
            with open(output_data_file_path, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('{}\n'.format(','.join(header)) + content)

        memory_file = io.BytesIO()
        with tarfile.open(fileobj=memory_file, mode='w:gz') as tar:
            tar.add(output_data_file_path,
                    os.path.relpath(output_data_file_path, output_tmp_dir))
            tar.add(output_data_meta_file_path,
                    os.path.relpath(output_data_meta_file_path, output_tmp_dir))
        memory_file.seek(0)
    finally:
        # The archive lives in memory, so the scratch files are no longer
        # needed on any path, including errors and the empty-table case.
        try:
            shutil.rmtree(output_tmp_dir)
        except Exception as e:
            stat_logger.warning(e)

    tar_file_name = 'job_{}_{}_{}_{}_output_data.tar.gz'.format(
        request_data['job_id'], request_data['component_name'],
        request_data['role'], request_data['party_id'])
    return send_file(memory_file,
                     attachment_filename=tar_file_name,
                     as_attachment=True)
예제 #8
0
def get_mysql_info():
    """Report whether the MySQL database can be reached.

    Returns 404 in standalone mode. Returns 503 with the exception text when
    entering ``DB.connection_context()`` or calling ``DB.random()`` fails.
    Otherwise returns a 200 response.
    """
    if IS_STANDALONE:
        return error_response(404, 'mysql only available on cluster mode')

    try:
        with DB.connection_context():
            DB.random()
    except Exception as exc:
        response = error_response(503, str(exc))
    else:
        response = error_response(200)
    return response
예제 #9
0
def list_job():
    """List jobs matching the request filters, newest first by default.

    Supported filters in the JSON body:

    * ``job_id`` and ``description``: ``contains`` match.
    * ``party_id``: must be convertible with ``int()``.
    * ``partner``: ``contains`` match against the job's ``roles``.
    * ``role`` and ``status``: a string or a list of values, each of which
      must appear in ``valid_query_parameters``.

    Jobs tagged ``submit_failed`` are excluded. Pagination comes from
    ``parse_limit_and_offset`` and ordering from ``parse_order_by``. Each
    returned job gets an integer ``party_id`` and a sorted ``partners``
    list: the guest/host/arbiter party ids other than its own.
    """
    limit, offset = parse_limit_and_offset()
    body = request.json

    query = {
        'tag': ('!=', 'submit_failed'),
    }

    for i in ('job_id', 'description'):
        if body.get(i) is not None:
            query[i] = ('contains', body[i])
    if body.get('party_id') is not None:
        try:
            party_id = int(body['party_id'])
        except Exception:
            return error_response(400, "Invalid parameter 'party_id'.")
        query['party_id'] = ('contains', party_id)
    if body.get('partner') is not None:
        # The partner value must come from the request; ``query`` never
        # holds a 'partner' key.
        query['roles'] = ('contains', body['partner'])

    for i in ('role', 'status'):
        values = body.get(i)
        if values is None:
            continue

        if isinstance(values, str):
            values = [values]
        if not isinstance(values, list):
            return error_response(400, f"Invalid parameter '{i}'.")
        try:
            values = set(values)
        except TypeError:
            # Unhashable items (e.g. nested objects) cannot be valid values.
            return error_response(400, f"Invalid parameter '{i}'.")

        for j in values:
            if j not in valid_query_parameters[i]:
                return error_response(400, f"Invalid parameter '{i}'.")
        query[i] = ('in_', values)

    jobs, count = job_utils.list_job(limit, offset, query,
                                     parse_order_by(('create_time', 'desc')))
    jobs = [job.to_human_model_dict() for job in jobs]

    for job in jobs:
        job['party_id'] = int(job['party_id'])

        job['partners'] = set()
        for i in ('guest', 'host', 'arbiter'):
            job['partners'].update(job['roles'].get(i, []))
        job['partners'].discard(job['party_id'])
        job['partners'] = sorted(job['partners'])

    return get_json_result(data={
        'jobs': jobs,
        'count': count,
    })
예제 #10
0
def parse_order_by(default=None):
    """Build the ordering argument for a list query from the request body.

    ``order_by`` and, only when ``order_by`` is given, ``order`` are read
    from the JSON body. Each must appear in ``valid_query_parameters``;
    otherwise the request is aborted with a 400 response.

    :param default: value returned when the request specifies no ordering.
    :return: ``[order_by]`` or ``[order_by, order]``, or ``default``.
    """
    body = request.json
    order_by = body.get('order_by')
    if order_by is None:
        return default
    if order_by not in valid_query_parameters['order_by']:
        abort(error_response(400, "Invalid parameter 'order_by'."))
    result = [order_by]

    order = body.get('order')
    if order is not None:
        if order not in valid_query_parameters['order']:
            # Same message format as for 'order_by' above.
            abort(error_response(400, "Invalid parameter 'order'."))
        result.append(order)

    return result
예제 #11
0
def get_eggroll_info():
    """Check that the eggroll rollsite endpoint accepts TCP connections.

    Returns 404 in standalone mode, or when the coordination proxy is not
    rollsite. Returns 503 when connecting to the configured rollsite host
    and port fails. Otherwise returns a 200 response.
    """
    if IS_STANDALONE:
        return error_response(404, 'eggroll only available on cluster mode')
    if PROXY != CoordinationProxyService.ROLLSITE:
        return error_response(
            404, 'coordination communication protocol is not rollsite')

    rollsite = ServiceRegistry.FATE_ON_EGGROLL['rollsite']
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # connect_ex reports failure as a non-zero errno instead of raising.
        reachable = sock.connect_ex((rollsite['host'], rollsite['port'])) == 0
    return error_response(200) if reachable else error_response(503)
예제 #12
0
파일: job_app.py 프로젝트: zeta1999/FATE
def dsl_generator():
    """Generate a predict DSL from a train DSL and a list of components.

    Request keys:

    * ``cpn_str``: a list of component names, or a comma-separated string
      optionally wrapped in ``[]``.
    * ``train_dsl``: a JSON string.
    * ``version``: the DSL parser version, default ``"2"``.
    * ``filename``: optional; when given, the result is sent back as an
      attachment.

    Any failure is logged and answered with a generic 210 error response.
    """
    data = request.json
    cpn_str = data.get("cpn_str", "")
    try:
        if not cpn_str:
            raise Exception("Component list should not be empty.")
        if isinstance(cpn_str, list):
            cpn_list = cpn_str
        else:
            # Plain substring test: str.find() returns -1 (truthy) when the
            # character is absent, so combining find() results with `and`
            # does not detect either separator reliably.
            if "/" in cpn_str or "\\" in cpn_str:
                raise Exception(
                    "Component list string should not contain '/' or '\\'.")
            cpn_str = cpn_str.replace(" ", "").replace("\n", "").strip(",[]")
            cpn_list = cpn_str.split(",")
        train_dsl = json_loads(data.get("train_dsl"))
        parser = schedule_utils.get_dsl_parser_by_version(
            data.get("version", "2"))
        predict_dsl = parser.deploy_component(cpn_list, train_dsl)

        filename = data.get("filename")
        if filename:
            # Client-controlled name: keep only the last path component so
            # the file stays inside TEMP_DIRECTORY.
            safe_filename = os.path.basename(filename)
            os.makedirs(TEMP_DIRECTORY, exist_ok=True)
            temp_filepath = os.path.join(TEMP_DIRECTORY, safe_filename)
            with open(temp_filepath, "w") as fout:
                fout.write(json.dumps(predict_dsl, indent=4))
            return send_file(open(temp_filepath, 'rb'),
                             as_attachment=True,
                             attachment_filename=safe_filename)
        return get_json_result(data=predict_dsl)
    except Exception as e:
        stat_logger.exception(e)
        return error_response(
            210, "DSL generating failed. For more details, "
            "please check logs/fate_flow/fate_flow_stat.log.")
예제 #13
0
def parse_limit_and_offset():
    """Read pagination parameters from the request body.

    ``limit`` defaults to 0 and ``page`` (1-based) defaults to 1. If either
    value cannot be converted with ``int()``, the request is aborted with a
    400 response.

    :return: ``(limit, offset)`` where ``offset == limit * (page - 1)``.
    """
    try:
        body = request.json
        limit = int(body.get('limit', 0))
        page_index = int(body.get('page', 1)) - 1
    except Exception:
        abort(error_response(400, "Invalid parameter 'limit' or 'page'."))

    offset = limit * page_index
    return limit, offset
예제 #14
0
def get_fateboard_info():
    """Return the FATEBoard address registered in ``ServiceRegistry``.

    Reads ``host`` and ``port`` from ``ServiceRegistry.FATEBOARD``. Returns
    404 if either is missing or falsy.
    """
    registry = ServiceRegistry.FATEBOARD
    host, port = registry.get('host'), registry.get('port')
    if host and port:
        return get_json_result(data={'host': host, 'port': port})
    return error_response(404, 'fateboard is not configured')
예제 #15
0
파일: info_app.py 프로젝트: kunchengit/FATE
def get_fateboard_info():
    """Return the FATEBoard address from the ``fateboard`` base config.

    Reads ``host`` and ``port`` from ``get_base_config('fateboard', {})``.
    Returns 404 if either is missing or falsy.
    """
    board_conf = get_base_config('fateboard', {})
    host, port = board_conf.get('host'), board_conf.get('port')
    if host and port:
        return get_json_result(data={'host': host, 'port': port})
    return error_response(404, 'fateboard is not configured')
예제 #16
0
def get_checkpoint():
    """Return one checkpoint, selected by ``step_index`` or ``step_name``.

    Checkpoints come from ``load_checkpoints()``. ``step_index`` takes
    precedence and must be convertible with ``int()`` (400 otherwise); the
    converted value is written back into the request body. Returns 400 when
    neither key is present and 404 when no checkpoint matches.
    """
    checkpoint_manager = load_checkpoints()
    body = request.json

    if 'step_index' in body:
        try:
            body['step_index'] = int(body['step_index'])
        except Exception:
            return error_response(400, "Invalid 'step_index'")
        lookup = checkpoint_manager.get_checkpoint_by_index
        step = body['step_index']
    elif 'step_name' in body:
        lookup = checkpoint_manager.get_checkpoint_by_name
        step = body['step_name']
    else:
        return error_response(400, "'step_index' or 'step_name' is required")

    checkpoint = lookup(step)
    if checkpoint is None:
        return error_response(404, "The checkpoint was not found.")
    return get_json_result(data=checkpoint.to_dict(True))
예제 #17
0
def check_job_log_dir():
    """Resolve and check the log directory of the job named in the request.

    ``job_id`` must be present in the JSON body (a missing key raises
    ``KeyError``) and is converted with ``str()``. The request is aborted
    with a 404 response when the directory does not exist.

    :return: ``(job_id, job_log_dir)``
    """
    job_id = str(request.json['job_id'])
    log_dir = job_utils.get_job_log_directory(job_id=job_id)

    if not os.path.exists(log_dir):
        message = (f"Log file path: '{log_dir}' not found. "
                   "Please check if the job id is valid.")
        abort(error_response(404, message))

    return job_id, log_dir
예제 #18
0
def list_task():
    """List tasks matching the request filters, oldest first by default.

    The optional filters ``job_id``, ``role``, ``party_id`` and
    ``component_name`` are taken from the JSON body when they are not
    ``None``. ``role`` must appear in ``valid_query_parameters['role']`` and
    ``party_id`` must be convertible with ``int()``; otherwise a 400
    response is returned.
    """
    limit, offset = parse_limit_and_offset()

    body = request.json
    query = {
        key: body[key]
        for key in ('job_id', 'role', 'party_id', 'component_name')
        if body.get(key) is not None
    }

    role = query.get('role')
    if role is not None and role not in valid_query_parameters['role']:
        return error_response(400, "Invalid parameter 'role'.")
    if 'party_id' in query:
        try:
            query['party_id'] = int(query['party_id'])
        except Exception:
            return error_response(400, "Invalid parameter 'party_id'.")

    tasks, count = job_utils.list_task(limit, offset, query,
                                       parse_order_by(('create_time', 'asc')))
    task_dicts = [task.to_human_model_dict() for task in tasks]
    return get_json_result(data={'tasks': task_dicts, 'count': count})
예제 #19
0
def authentication():
    """Verify the HMAC signature of the incoming HTTP request.

    Returns ``None`` (accept) when ``HTTP_APP_KEY`` or ``HTTP_SECRET_KEY``
    is not configured. Otherwise the ``TIMESTAMP`` (milliseconds),
    ``NONCE``, ``APP_KEY`` and ``SIGNATURE`` headers are required.

    The expected signature is the base64-encoded HMAC-SHA1, keyed with
    ``HTTP_SECRET_KEY``, of the newline-joined list of:

    * the timestamp, nonce and app key;
    * the request path with its query string;
    * the raw body, if the request is JSON;
    * the sorted, RFC 3986-quoted form fields, if any.

    Returns an error response to reject the request and ``None`` to accept
    it:

    * 400: malformed request (body sent as both JSON and form, bad
      timestamp, or non-ASCII signed fields).
    * 401: a required header is missing, or the app key is unknown.
    * 403: the signature does not match.
    * 425: the timestamp is outside ``MAX_TIMESTAMP_INTERVAL``.
    """
    from hmac import compare_digest

    if request.json and request.form:
        return error_response(400)

    if not (HTTP_APP_KEY and HTTP_SECRET_KEY):
        return

    for i in [
            'TIMESTAMP',
            'NONCE',
            'APP_KEY',
            'SIGNATURE',
    ]:
        if not request.headers.get(i):
            return error_response(401)

    try:
        timestamp = int(request.headers['TIMESTAMP']) / 1000
    except Exception:
        return error_response(400, 'Invalid TIMESTAMP')

    now = time()
    if not now - MAX_TIMESTAMP_INTERVAL < timestamp < now + MAX_TIMESTAMP_INTERVAL:
        return error_response(
            425,
            f'TIMESTAMP is more than {MAX_TIMESTAMP_INTERVAL} seconds away from the server time'
        )

    if not request.headers['NONCE']:
        return error_response(400, 'Invalid NONCE')

    if request.headers['APP_KEY'] != HTTP_APP_KEY:
        return error_response(401, 'Unknown APP_KEY')

    try:
        # int() above also accepts non-ASCII Unicode digits, and the other
        # headers or the decoded path may contain non-ASCII text. Reject
        # such requests as malformed rather than raising UnicodeEncodeError.
        message = b'\n'.join([
            request.headers['TIMESTAMP'].encode('ascii'),
            request.headers['NONCE'].encode('ascii'),
            request.headers['APP_KEY'].encode('ascii'),
            request.full_path.rstrip('?').encode('ascii'),
            request.data if request.json else b'',
            # quote_via: `urllib.parse.quote` replaces spaces with `%20`
            # safe: unreserved characters from rfc3986
            urlencode(
                sorted(request.form.items()), quote_via=quote,
                safe='-._~').encode('ascii') if request.form else b'',
        ])
    except UnicodeEncodeError:
        return error_response(400)

    expected_signature = b64encode(
        HMAC(HTTP_SECRET_KEY.encode('ascii'), message, 'sha1').digest())
    # Constant-time comparison: `!=` may stop at the first differing
    # character and leak timing information about the expected signature.
    if not compare_digest(expected_signature,
                          request.headers['SIGNATURE'].encode('utf-8')):
        return error_response(403)
예제 #20
0
def load_checkpoints():
    """Build a ``CheckpointManager`` for the request and load it from disk.

    ``role``, ``party_id``, ``model_id``, ``model_version`` and
    ``component_name`` are required in the JSON body; if they fail
    ``check_config``, the request is aborted with 400. The manager is
    created with ``mkdir=False`` and then filled via
    ``load_checkpoints_from_disk()``.
    """
    required_args = [
        'role', 'party_id', 'model_id', 'model_version', 'component_name'
    ]
    try:
        body = request.json
        check_config(body, required_args)
    except Exception as e:
        abort(error_response(400, str(e)))

    manager_kwargs = {arg: body[arg] for arg in required_args}
    checkpoint_manager = CheckpointManager(**manager_kwargs, mkdir=False)
    checkpoint_manager.load_checkpoints_from_disk()
    return checkpoint_manager
예제 #21
0
def table_download():
    """Send a storage table as a ``.tar.gz`` attachment.

    The table is identified by ``name`` and ``namespace`` in the JSON body.
    ``head`` (default ``True``) is forwarded as ``need_head``. Returns a 210
    error response when the ``StorageTableMeta`` object is falsy.
    """
    request_data = request.json
    from fate_flow.component_env_utils.env_utils import import_component_output_depend
    import_component_output_depend()

    name = request_data.get("name")
    namespace = request_data.get("namespace")
    data_table_meta = storage.StorageTableMeta(name=name, namespace=namespace)
    if not data_table_meta:
        return error_response(response_code=210,
                              retmsg=f'no found table:{namespace}, {name}')

    return TableStorage.send_table(
        output_tables_meta={"table": data_table_meta},
        tar_file_name=f'table_{namespace}_{name}.tar.gz',
        need_head=request_data.get("head", True))
예제 #22
0
파일: model_app.py 프로젝트: zeta1999/FATE
def get_predict_conf():
    """Build a predict-configuration template for a trained model.

    The first guest-side copy of ``model_id``/``model_version`` is looked up
    under ``<project base>/model_local_cache``. Its pipeline's inference DSL
    and train runtime conf are read, and the DSL parser for the conf's
    ``dsl_version`` (default ``'1'``) generates the template. With
    ``filename`` the result is sent as an attachment; otherwise it is
    returned in the JSON result.

    Returns a 210 error response when no model is found or the template is
    empty.
    """
    request_data = request.json
    required_parameters = ['model_id', 'model_version']
    check_config(request_data, required_parameters)
    model_dir = os.path.join(get_project_base_directory(), 'model_local_cache')
    # model_id / model_version come from the client. Escape them so glob
    # metacharacters ('*', '?', '[') match literally; only the party id
    # between the '#' separators is meant to be a wildcard.
    escaped_model_id = glob.escape(str(request_data['model_id']))
    escaped_model_version = glob.escape(str(request_data['model_version']))
    model_fp_list = glob.glob(
        glob.escape(model_dir) +
        f"/guest#*#{escaped_model_id}/{escaped_model_version}")

    predict_conf = ''
    if model_fp_list:
        fp = model_fp_list[0]
        pipeline_model = PipelinedModel(model_id=fp.split('/')[-2],
                                        model_version=fp.split('/')[-1])
        pipeline = pipeline_model.read_component_model('pipeline',
                                                       'pipeline')['Pipeline']
        predict_dsl = json_loads(pipeline.inference_dsl)

        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        parser = schedule_utils.get_dsl_parser_by_version(
            train_runtime_conf.get('dsl_version', '1'))
        predict_conf = parser.generate_predict_conf_template(
            predict_dsl=predict_dsl,
            train_conf=train_runtime_conf,
            model_id=request_data['model_id'],
            model_version=request_data['model_version'])

    if not predict_conf:
        return error_response(
            210,
            "No model found, please check if arguments are specified correctly.")

    filename = request_data.get("filename")
    if not filename:
        return get_json_result(data=predict_conf)

    # Client-controlled name: keep only the last path component so the
    # file stays inside TEMP_DIRECTORY.
    safe_filename = os.path.basename(filename)
    os.makedirs(TEMP_DIRECTORY, exist_ok=True)
    temp_filepath = os.path.join(TEMP_DIRECTORY, safe_filename)
    with open(temp_filepath, "w") as fout:
        fout.write(json_dumps(predict_conf, indent=4))
    return send_file(open(temp_filepath, "rb"),
                     as_attachment=True,
                     attachment_filename=safe_filename)
예제 #23
0
파일: job_app.py 프로젝트: zeta1999/FATE
def job_log():
    """Send every log file of a job as ``job_<job_id>_log.tar.gz``.

    All files below the job's log directory are added to an in-memory
    tarball, with paths relative to that directory. Returns a 210 error
    response when ``job_id`` is missing or empty, or when the log directory
    does not exist.
    """
    job_id = request.json.get('job_id', '')
    if not job_id:
        # NOTE(review): with an empty id the resolved directory is presumably
        # the log root shared by all jobs. Refuse instead of archiving it.
        return error_response(210, "job_id is required.")

    job_log_dir = job_utils.get_job_log_directory(job_id=job_id)
    if not os.path.exists(job_log_dir):
        return error_response(
            210,
            "Log file path: {} not found. Please check if the job id is valid."
            .format(job_log_dir))

    memory_file = io.BytesIO()
    # The context manager closes the archive even if adding a file fails.
    with tarfile.open(fileobj=memory_file, mode='w:gz') as tar:
        for root, _dirs, files in os.walk(job_log_dir):
            for file_name in files:
                full_path = os.path.join(root, file_name)
                tar.add(full_path, os.path.relpath(full_path, job_log_dir))
    memory_file.seek(0)
    return send_file(memory_file,
                     attachment_filename='job_{}_log.tar.gz'.format(job_id),
                     as_attachment=True)
예제 #24
0
def register():
    """Register a component provider and reload the component registry.

    Request keys (JSON body, or form fields when there is no JSON):

    * ``name`` and ``version``.
    * ``path``: must be an existing directory (400 otherwise).
    * ``class_path``: optional, defaults to
      ``ComponentRegistry.get_default_class_path()``.

    A provider-registrar worker is started. On exit code 0 the registry is
    reloaded, and the provider version must then be present in memory.
    """
    info = request.json or request.form.to_dict()
    if not Path(info["path"]).is_dir():
        return error_response(400, "invalid path")

    default_class_path = ComponentRegistry.get_default_class_path()
    provider = ComponentProvider(name=info["name"],
                                 version=info["version"],
                                 path=info["path"],
                                 class_path=info.get("class_path",
                                                     default_class_path))
    code, std = WorkerManager.start_general_worker(
        worker_name=WorkerName.PROVIDER_REGISTRAR, provider=provider)
    if code != 0:
        return get_json_result(retcode=RetCode.OPERATING_ERROR,
                               retmsg=f"register failed:\n{std}")

    ComponentRegistry.load()
    loaded_versions = ComponentRegistry.get_providers().get(provider.name, {})
    if loaded_versions.get(provider.version, None) is None:
        return get_json_result(retcode=RetCode.OPERATING_ERROR,
                               retmsg="not load into memory")
    return get_json_result()
예제 #25
0
파일: model_app.py 프로젝트: zeta1999/FATE
def operate_model(model_operation):
    """Store, restore, export or import a pipelined model.

    ``model_operation`` must be one of ``ModelOperation.STORE``,
    ``RESTORE``, ``EXPORT`` or ``IMPORT``; anything else raises
    ``Exception``. The request (JSON body, or form fields) must carry
    ``model_id``, ``model_version``, ``role`` and ``party_id``.
    ``model_id`` is replaced by the id built with ``gen_party_model_id``.

    * IMPORT: saves the uploaded ``file`` under ``TEMP_DIRECTORY`` and
      unpacks it. Requires ``party_id`` to appear in the model's train
      roles, and saves model info to the database if it is not there yet.
      Any failure is recorded and re-raised.
    * EXPORT: packages the model and sends the archive as an attachment.
      Returns a 210 error response if the model does not exist or
      packaging fails.
    * STORE / RESTORE: submits a model-operation job through
      ``DAGScheduler`` and returns its submit result with the new job id.
    """
    request_config = request.json or request.form.to_dict()
    job_id = job_utils.generate_job_id()
    if model_operation not in [
            ModelOperation.STORE, ModelOperation.RESTORE,
            ModelOperation.EXPORT, ModelOperation.IMPORT
    ]:
        raise Exception(
            'Can not support this operating now: {}'.format(model_operation))
    required_arguments = ["model_id", "model_version", "role", "party_id"]
    check_config(request_config, required_arguments=required_arguments)
    request_config["model_id"] = gen_party_model_id(
        model_id=request_config["model_id"],
        role=request_config["role"],
        party_id=request_config["party_id"])
    if model_operation in [ModelOperation.EXPORT, ModelOperation.IMPORT]:
        if model_operation == ModelOperation.IMPORT:
            try:
                file = request.files.get('file')
                if file is None or not file.filename:
                    raise Exception(
                        "Model file is required, please upload it with the field name 'file'."
                    )
                # The upload name is client-controlled: keep only its last
                # path component so it cannot escape TEMP_DIRECTORY.
                file_path = os.path.join(TEMP_DIRECTORY,
                                         os.path.basename(file.filename))
                try:
                    os.makedirs(os.path.dirname(file_path), exist_ok=True)
                    file.save(file_path)
                except Exception as e:
                    # Remove a partially written upload. rmtree() cannot
                    # delete a regular file and would hide the real error.
                    try:
                        if os.path.isfile(file_path):
                            os.remove(file_path)
                    except OSError:
                        pass
                    raise e
                request_config['file'] = file_path
                model = pipelined_model.PipelinedModel(
                    model_id=request_config["model_id"],
                    model_version=request_config["model_version"])
                model.unpack_model(file_path)

                pipeline = model.read_component_model('pipeline',
                                                      'pipeline')['Pipeline']
                train_runtime_conf = json_loads(pipeline.train_runtime_conf)
                # Accept the party id as either int or str.
                permitted_party_id = []
                for value in train_runtime_conf.get('role', {}).values():
                    for v in value:
                        permitted_party_id.extend([v, str(v)])
                if request_config["party_id"] not in permitted_party_id:
                    shutil.rmtree(model.model_path)
                    raise Exception(
                        "party id {} is not in model roles, please check if the party id is valid."
                        .format(request_config["party_id"]))
                try:
                    adapter = JobRuntimeConfigAdapter(train_runtime_conf)
                    job_parameters = adapter.get_common_parameters().to_dict()
                    with DB.connection_context():
                        db_model = MLModel.get_or_none(
                            MLModel.f_job_id == job_parameters.get(
                                "model_version"),
                            MLModel.f_role == request_config["role"])
                    if not db_model:
                        model_info = model_utils.gather_model_info_data(model)
                        model_info['imported'] = 1
                        model_info['job_id'] = model_info['f_model_version']
                        model_info['size'] = model.calculate_model_file_size()
                        model_info['role'] = request_config["model_id"].split(
                            '#')[0]
                        model_info['party_id'] = request_config[
                            "model_id"].split('#')[1]
                        if model_utils.compare_version(
                                model_info['f_fate_version'], '1.5.1') == 'lt':
                            model_info['roles'] = model_info.get(
                                'f_train_runtime_conf', {}).get('role', {})
                            model_info['initiator_role'] = model_info.get(
                                'f_train_runtime_conf',
                                {}).get('initiator', {}).get('role')
                            model_info['initiator_party_id'] = model_info.get(
                                'f_train_runtime_conf',
                                {}).get('initiator', {}).get('party_id')
                            model_info[
                                'work_mode'] = adapter.get_job_work_mode()
                            model_info['parent'] = False if model_info.get(
                                'f_inference_dsl') else True
                        model_utils.save_model_info(model_info)
                    else:
                        stat_logger.info(
                            f'job id: {job_parameters.get("model_version")}, '
                            f'role: {request_config["role"]} model info already existed in database.'
                        )
                except peewee.IntegrityError as e:
                    stat_logger.exception(e)
                operation_record(request_config, "import", "success")
                return get_json_result()
            except Exception:
                operation_record(request_config, "import", "failed")
                raise
        else:
            try:
                model = pipelined_model.PipelinedModel(
                    model_id=request_config["model_id"],
                    model_version=request_config["model_version"])
                if model.exists():
                    archive_file_path = model.packaging_model()
                    operation_record(request_config, "export", "success")
                    return send_file(archive_file_path,
                                     attachment_filename=os.path.basename(
                                         archive_file_path),
                                     as_attachment=True)
                else:
                    operation_record(request_config, "export", "failed")
                    res = error_response(
                        response_code=210,
                        retmsg="Model {} {} is not exist.".format(
                            request_config.get("model_id"),
                            request_config.get("model_version")))
                    return res
            except Exception as e:
                operation_record(request_config, "export", "failed")
                stat_logger.exception(e)
                return error_response(response_code=210, retmsg=str(e))
    else:
        data = {}
        job_dsl, job_runtime_conf = gen_model_operation_job_config(
            request_config, model_operation)
        submit_result = DAGScheduler.submit(
            {
                'job_dsl': job_dsl,
                'job_runtime_conf': job_runtime_conf
            },
            job_id=job_id)
        data.update(submit_result)
        operation_record(data=job_runtime_conf,
                         oper_type=model_operation,
                         oper_status='')
        return get_json_result(job_id=job_id, data=data)
예제 #26
0
파일: model_app.py 프로젝트: tarada/FATE
def operate_model(model_operation):
    """Run a model operation (store / restore / export / import) for one party.

    Request parameters are taken from the JSON body or, failing that, the form
    fields, and are checked by ``check_config`` for ``model_id``,
    ``model_version``, ``role`` and ``party_id``. ``model_id`` is rewritten in
    place to the value returned by ``gen_party_model_id``.

    * IMPORT: saves the uploaded ``file`` under ``TEMP_DIRECTORY``, unpacks it
      into a ``PipelinedModel``, checks that ``party_id`` is one of the
      model's training roles, and creates an ``MLModel`` row when none exists
      for that job id / role. Returns an empty ``get_json_result()``.
    * EXPORT: packages the model and returns the archive as an attachment, or
      an ``error_response`` with code 210 when it does not exist or fails.
    * STORE / RESTORE: submitted as a DAG job. Returns the job id and the
      paths reported by ``DAGScheduler.submit``.

    Every export/import outcome is recorded with ``operation_record``.

    Raises:
        Exception: for an unsupported ``model_operation``, and for any import
            failure (recorded as "failed" before being re-raised).
    """
    request_config = request.json or request.form.to_dict()
    job_id = job_utils.generate_job_id()
    if model_operation not in [
            ModelOperation.STORE, ModelOperation.RESTORE,
            ModelOperation.EXPORT, ModelOperation.IMPORT
    ]:
        raise Exception(
            'Can not support this operating now: {}'.format(model_operation))
    required_arguments = ["model_id", "model_version", "role", "party_id"]
    check_config(request_config, required_arguments=required_arguments)
    request_config["model_id"] = gen_party_model_id(
        model_id=request_config["model_id"],
        role=request_config["role"],
        party_id=request_config["party_id"])
    if model_operation in [ModelOperation.EXPORT, ModelOperation.IMPORT]:
        if model_operation == ModelOperation.IMPORT:
            try:
                upload = request.files.get('file')
                # The upload name is client-controlled: keep only its last
                # path component so it cannot escape TEMP_DIRECTORY ("../x").
                filename = os.path.basename(
                    upload.filename) if upload is not None else ''
                if not filename:
                    raise Exception(
                        'no model file uploaded, please upload the model archive as "file"'
                    )
                file_path = os.path.join(TEMP_DIRECTORY, filename)
                try:
                    os.makedirs(os.path.dirname(file_path), exist_ok=True)
                    upload.save(file_path)
                except Exception:
                    # file_path is a regular file (rmtree cannot remove it):
                    # drop any partial write, then re-raise the original error.
                    if os.path.isfile(file_path):
                        os.remove(file_path)
                    raise
                request_config['file'] = file_path
                model = pipelined_model.PipelinedModel(
                    model_id=request_config["model_id"],
                    model_version=request_config["model_version"])
                model.unpack_model(file_path)

                pipeline = model.read_component_model('pipeline',
                                                      'pipeline')['Pipeline']
                train_runtime_conf = json_loads(pipeline.train_runtime_conf)
                # Accept each party id both as stored and as a string, since
                # the request may carry either form.
                permitted_party_id = []
                for party_ids in train_runtime_conf.get('role', {}).values():
                    for party_id in party_ids:
                        permitted_party_id.extend([party_id, str(party_id)])
                if request_config["party_id"] not in permitted_party_id:
                    shutil.rmtree(model.model_path)
                    raise Exception(
                        "party id {} is not in model roles, please check if the party id is valid."
                        .format(request_config["party_id"]))
                job_parameters = train_runtime_conf["job_parameters"]
                try:
                    with DB.connection_context():
                        existing_model = MLModel.get_or_none(
                            MLModel.f_job_id == job_parameters["model_version"],
                            MLModel.f_role == request_config["role"])
                        if not existing_model:
                            MLModel.create(
                                f_role=request_config["role"],
                                f_party_id=request_config["party_id"],
                                f_roles=train_runtime_conf["role"],
                                f_job_id=job_parameters["model_version"],
                                f_model_id=job_parameters["model_id"],
                                f_model_version=job_parameters["model_version"],
                                f_initiator_role=train_runtime_conf[
                                    "initiator"]["role"],
                                f_initiator_party_id=train_runtime_conf[
                                    "initiator"]["party_id"],
                                f_runtime_conf=train_runtime_conf,
                                f_work_mode=job_parameters["work_mode"],
                                f_dsl=json_loads(pipeline.train_dsl),
                                f_imported=1,
                                f_job_status='complete')
                        else:
                            stat_logger.info(
                                f'job id: {job_parameters["model_version"]}, '
                                f'role: {request_config["role"]} model info already existed in database.'
                            )
                except peewee.IntegrityError as e:
                    # Presumably a duplicate row (e.g. a concurrent import);
                    # logged and the import is still reported as a success.
                    stat_logger.exception(e)
                operation_record(request_config, "import", "success")
                return get_json_result()
            except Exception:
                operation_record(request_config, "import", "failed")
                raise
        else:
            try:
                model = pipelined_model.PipelinedModel(
                    model_id=request_config["model_id"],
                    model_version=request_config["model_version"])
                if model.exists():
                    archive_file_path = model.packaging_model()
                    operation_record(request_config, "export", "success")
                    return send_file(archive_file_path,
                                     attachment_filename=os.path.basename(
                                         archive_file_path),
                                     as_attachment=True)
                else:
                    operation_record(request_config, "export", "failed")
                    res = error_response(
                        response_code=210,
                        retmsg="Model {} {} is not exist.".format(
                            request_config.get("model_id"),
                            request_config.get("model_version")))
                    return res
            except Exception as e:
                operation_record(request_config, "export", "failed")
                stat_logger.exception(e)
                return error_response(response_code=210, retmsg=str(e))
    else:
        data = {}
        job_dsl, job_runtime_conf = gen_model_operation_job_config(
            request_config, model_operation)
        job_id, job_dsl_path, job_runtime_conf_path, logs_directory, _model_info, board_url = DAGScheduler.submit(
            {
                'job_dsl': job_dsl,
                'job_runtime_conf': job_runtime_conf
            },
            job_id=job_id)
        data.update({
            'job_dsl_path': job_dsl_path,
            'job_runtime_conf_path': job_runtime_conf_path,
            'board_url': board_url,
            'logs_directory': logs_directory
        })
        operation_record(data=job_runtime_conf,
                         oper_type=model_operation,
                         oper_status='')
        return get_json_result(job_id=job_id, data=data)
예제 #27
0
파일: info_app.py 프로젝트: kunchengit/FATE
def internal_server_error(e):
    """Log an unhandled error and turn it into an ``error_response`` with code 500.

    The error is logged through ``stat_logger.exception``. The response
    message is ``str(e)``.
    """
    stat_logger.exception(e)
    error_message = str(e)
    return error_response(500, error_message)
예제 #28
0
def download_model(party_model_id, model_version):
    """Return the published model data for ``party_model_id`` / ``model_version``.

    Every ``~`` in ``party_model_id`` is turned back into ``#`` before the
    lookup. This is presumably because the id arrives via a URL, where ``#``
    cannot be used; confirm against the route definition. Responds with an
    ``error_response`` 404 when ``publish_model.download_model`` returns None.
    """
    lookup_id = '#'.join(party_model_id.split('~'))
    model_data = publish_model.download_model(lookup_id, model_version)
    if model_data is not None:
        return get_json_result(data=model_data)
    return error_response(404, 'model not found')
예제 #29
0
def component_output_data_download():
    """Export a component's output tables as a gzipped tar attachment.

    Reads ``job_id``, ``component_name``, ``role`` and ``party_id`` from the
    JSON body. It also reads optional ``limit`` (rows per table; the default -1
    never matches the row counter, so it means "no limit") and ``head``
    (prepend a CSV header line, default True).

    Each output table that yields at least one row is written to
    ``<name>.csv``. A companion ``<name>.meta`` JSON file holds its header.
    Both files are packed into an in-memory ``.tar.gz``. The temporary
    directory is always removed, even when reading a table fails.

    Returns an ``error_response`` with code 210 in these cases:
    * the output table meta cannot be resolved;
    * there is no output;
    * ``limit`` is 0;
    * every output table is empty.
    """
    request_data = request.json
    try:
        output_tables_meta = get_component_output_tables_meta(task_data=request_data)
    except Exception as e:
        stat_logger.exception(e)
        return error_response(210, str(e))
    limit = request_data.get('limit', -1)
    if not output_tables_meta:
        return error_response(response_code=210, retmsg='no data')
    if limit == 0:
        return error_response(response_code=210, retmsg='limit is 0')

    output_tmp_dir = os.path.join(os.getcwd(), 'tmp/{}'.format(fate_uuid()))
    try:
        # (data csv path, meta json path) for every table that produced rows
        exported_files = []
        for output_name, output_table_meta in output_tables_meta.items():
            output_data_count = 0
            have_data_label = False
            is_str = False
            output_data_file_path = "{}/{}.csv".format(output_tmp_dir, output_name)
            os.makedirs(os.path.dirname(output_data_file_path), exist_ok=True)
            with open(output_data_file_path, 'w') as fw:
                with storage.Session.build(name=output_table_meta.get_name(),
                                           namespace=output_table_meta.get_namespace()) as storage_session:
                    output_table = storage_session.get_table()
                    for k, v in output_table.collect():
                        data_line, have_data_label, is_str = get_component_output_data_line(src_key=k, src_value=v)
                        fw.write('{}\n'.format(','.join(map(str, data_line))))
                        output_data_count += 1
                        if output_data_count == limit:
                            break
            if not output_data_count:
                continue

            header = get_component_output_data_schema(output_table_meta=output_table_meta,
                                                      have_data_label=have_data_label, is_str=is_str)
            output_data_meta_file_path = "{}/{}.meta".format(output_tmp_dir, output_name)
            with open(output_data_meta_file_path, 'w') as fw:
                json.dump({'header': header}, fw, indent=4)
            if request_data.get('head', True) and header:
                # The header depends on flags returned while reading rows, so
                # it can only be prepended once the data file is complete.
                with open(output_data_file_path, 'r+') as f:
                    content = f.read()
                    f.seek(0, 0)
                    f.write('{}\n'.format(','.join(header)) + content)
            exported_files.append((output_data_file_path, output_data_meta_file_path))

        if not exported_files:
            return error_response(response_code=210, retmsg='no data')

        # Pack into memory so the temp directory can be removed before sending.
        memory_file = io.BytesIO()
        with tarfile.open(fileobj=memory_file, mode='w:gz') as tar:
            for data_path, meta_path in exported_files:
                tar.add(data_path, os.path.relpath(data_path, output_tmp_dir))
                tar.add(meta_path, os.path.relpath(meta_path, output_tmp_dir))
        memory_file.seek(0)
    finally:
        if os.path.isdir(output_tmp_dir):
            try:
                shutil.rmtree(output_tmp_dir)
            except Exception as e:
                # best-effort cleanup: never fail the download because of it
                stat_logger.warning(e)

    tar_file_name = 'job_{}_{}_{}_{}_output_data.tar.gz'.format(request_data['job_id'],
                                                                request_data['component_name'],
                                                                request_data['role'], request_data['party_id'])
    return send_file(memory_file, attachment_filename=tar_file_name, as_attachment=True)