Example #1
0
 def get_model(grid_address: str, worker_id: str, request_key: str,
               model_id: int) -> List:
     """Download a model's parameters from the model-centric ``get-model`` endpoint.

     Args:
         grid_address: Host (and optional ``:port``) of the grid node, without a
             scheme; ``http://`` is prepended here.
         worker_id: Sent as the ``worker_id`` query parameter.
         request_key: Sent as the ``request_key`` query parameter.
         model_id: Sent as the ``model_id`` query parameter.

     Returns:
         The response body decoded by ``deserialize_model_params``.
     """
     # Hand the query to requests via ``params`` so every value is URL-encoded;
     # interpolating the values straight into the URL breaks as soon as one of
     # them contains a reserved character such as ``&``, ``=``, ``+`` or a space.
     query = {
         "worker_id": worker_id,
         "request_key": request_key,
         "model_id": model_id,
     }
     req = requests.get(
         f"http://{grid_address}/model-centric/get-model", params=query)
     # TODO migrate to syft-core protobufs
     return deserialize_model_params(req.content)
Example #2
0
def get_model_params(grid_address: str, worker_id: str, request_key: str,
                     model_id: str) -> list[T.Tensor]:
    """Fetch a model's parameters from the grid and deserialize them.

    Issues an HTTP GET to ``/model-centric/get-model`` on ``grid_address``
    with ``worker_id``, ``request_key`` and ``model_id`` as query parameters,
    then decodes the raw response body with ``deserialize_model_params``.
    """
    endpoint = f"http://{grid_address}/model-centric/get-model"
    query = dict(worker_id=worker_id, request_key=request_key,
                 model_id=model_id)
    reply = requests.get(endpoint, params=query)
    return deserialize_model_params(reply.content)
Example #3
0
def retrieve_model_params(grid_address: str, name: str,
                          version: str) -> list[T.Tensor]:
    """Download the ``latest`` checkpoint of a named model version.

    Issues an HTTP GET to ``/model-centric/retrieve-model`` on
    ``grid_address`` with ``name``, ``version`` and ``checkpoint="latest"``
    as query parameters, and decodes the body with
    ``deserialize_model_params``.
    """
    url = f"http://{grid_address}/model-centric/retrieve-model"
    reply = requests.get(
        url, params=dict(name=name, version=version, checkpoint="latest"))
    return deserialize_model_params(reply.content)
Example #4
0
 def unserialize_model_params(bin: bytes):
     """Turn a serialized model, checkpoint or diff (as stored in the db) back
     into a list of tensors.

     Thin alias that delegates directly to ``deserialize_model_params``.
     """
     return deserialize_model_params(bin)
def sanity_check_hosted_plan(
    name: str,
    model: SyModule,
    plan_inputs: OrderedDict,
    plan_output_params_idx: TypeList[int],
    plan_type: str = "list",
) -> TypeList[TypeTuple[type, th.Size]]:
    """Run one full model-centric training cycle against a locally hosted model.

    Steps, in order:

    1. Authenticate over WebSocket.
    2. Request a training cycle.
    3. Download the model parameters and the training plan over HTTP.
    4. Execute the plan locally.
    5. Report the parameter diff over WebSocket.
    6. Download the ``latest`` checkpoint of the model.

    Note: this function reads the module-level names ``DOMAIN_PORT`` and
    ``auth_token``; both must be defined before it is called.

    Args:
        name: Name of the hosted model; version ``"1.0"`` is always used.
        model: Local model whose ``parameters()`` are the reference the
            reported diff is computed against (``original - updated``).
        plan_inputs: Inputs for the training plan. Its ``"params"`` entry is
            overwritten *in place* with the downloaded parameters. For
            torchscript plans the values are passed positionally in the
            dict's insertion order.
        plan_output_params_idx: Indices into the plan's outputs that hold the
            updated parameters.
        plan_type: ``"torchscript"`` selects torchscript decoding and
            positional invocation. Any other value is sent as
            ``receive_operations_as`` and decoded as a regular ``Plan``.

    Returns:
        ``(type, shape)`` for every parameter of the model's latest checkpoint.
    """
    grid_address = f"localhost:{DOMAIN_PORT}"

    def send_ws_message(data: JSONDict) -> JSONDict:
        """Send one JSON message on a fresh WebSocket and return the JSON reply."""
        ws = create_connection("ws://" + grid_address)
        # Always close the connection, even if send/recv fails; previously a
        # new socket was opened per call and never explicitly closed.
        try:
            ws.send(json.dumps(data))
            message = ws.recv()
        finally:
            ws.close()
        return json.loads(message)

    def get_model(
        grid_address: str, worker_id: str, request_key: str, model_id: int
    ) -> List:
        """Download and deserialize the model parameters for this cycle."""
        # ``params`` lets requests URL-encode each value instead of splicing
        # raw strings into the query.
        req = requests.get(
            f"http://{grid_address}/model-centric/get-model",
            params={
                "worker_id": worker_id,
                "request_key": request_key,
                "model_id": model_id,
            },
        )
        # TODO migrate to syft-core protobufs
        return deserialize_model_params(req.content)

    def get_plan(
        grid_address: str,
        worker_id: int,
        request_key: str,
        plan_id: int,
        plan_type: str,
    ) -> TypeUnion[PlanTorchscript, Plan]:
        """Download a plan and decode it according to ``plan_type``."""
        req = requests.get(
            f"http://{grid_address}/model-centric/get-plan",
            params={
                "worker_id": worker_id,
                "request_key": request_key,
                "plan_id": plan_id,
                "receive_operations_as": plan_type,
            },
        )

        if plan_type == "torchscript":
            pb = PlanTorchscriptPB()
            pb.ParseFromString(req.content)
            return PlanTorchscript._proto2object(pb)
        pb = PlanPB()
        pb.ParseFromString(req.content)
        return deserialize(pb)

    # Authenticate for cycle
    auth_response = send_ws_message(
        {
            "type": "model-centric/authenticate",
            "data": {
                "model_name": name,
                "model_version": "1.0",
                "auth_token": auth_token,
            },
        }
    )
    worker_id = auth_response["data"]["worker_id"]

    # Do cycle request
    cycle_response = send_ws_message(
        {
            "type": "model-centric/cycle-request",
            "data": {
                "worker_id": worker_id,
                "model": name,
                "version": "1.0",
                "ping": 1,
                "download": 10000,
                "upload": 10000,
            },
        }
    )
    cycle_data = cycle_response["data"]
    request_key = cycle_data["request_key"]
    model_id = cycle_data["model_id"]
    training_plan_id = cycle_data["plans"]["training_plan"]

    # Download model
    model_params_downloaded = get_model(grid_address, worker_id, request_key, model_id)

    # Download & execute plan
    plan = get_plan(grid_address, worker_id, request_key, training_plan_id, plan_type)
    plan_inputs["params"] = [
        th.nn.Parameter(param) for param in model_params_downloaded
    ]

    if plan_type == "torchscript":
        # kwargs are not supported in torchscript plan
        plan_output = plan(*plan_inputs.values())
    else:
        plan_output = plan(**plan_inputs)

    updated_params = [plan_output[idx] for idx in plan_output_params_idx]

    # Report model diff. NOTE(review): zip() silently truncates if the plan
    # returns a different number of params than model.parameters() — confirm
    # the lengths always match.
    diff = [orig - new for new, orig in zip(updated_params, model.parameters())]
    diff_serialized = serialize(wrap_model_params(diff)).SerializeToString()

    send_ws_message(
        {
            "type": "model-centric/report",
            "data": {
                "worker_id": worker_id,
                "request_key": request_key,
                "diff": base64.b64encode(diff_serialized).decode("ascii"),
            },
        }
    )

    # Check new model
    retrieve_response = requests.get(
        f"http://{grid_address}/model-centric/retrieve-model",
        params={
            "name": name,
            "version": "1.0",
            "checkpoint": "latest",
        },
    )
    new_model_params = deserialize_model_params(retrieve_response.content)
    return [(type(v), v.shape) for v in new_model_params]