Example No. 1
    def _handle_end_training(
        self, message: coordinator_pb2.EndTrainingRequest, participant_id: str
    ) -> coordinator_pb2.EndTrainingReply:
        """Handles a EndTraining request.
        Args:
            message (:class:`~.coordinator_pb2.EndTrainingRequest`): The request to handle.
            participant_id (:obj:`str`): The id of the participant making the request.
        Returns:
            :class:`~.coordinator_pb2.EndTrainingReply`: The reply to the participant.
        """

        # TODO: Ideally we want to know for which round the participant is
        # submitting the updates and raise an exception if it is the wrong
        # round.
        weights_proto, number_samples, metrics_proto = (
            message.weights,
            message.number_samples,
            message.metrics,
        )

        # record the request data
        weight_update: Tuple[List[ndarray], int] = (
            [proto_to_ndarray(pnda) for pnda in weights_proto],
            number_samples,
        )
        metrics: Dict[str, List[ndarray]] = {
            k: [proto_to_ndarray(v) for v in mv.metrics]
            for k, mv in metrics_proto.items()
        }
        self.round.add_updates(participant_id, weight_update, metrics)

        # The round is over. Run the aggregation
        if self.round.is_finished():
            logger.info(
                "Running aggregation for round", current_round=self.current_round
            )

            self.weights = self.aggregator.aggregate(self.round.get_weight_updates())

            # update the round or finish the training session
            if self.current_round == self.num_rounds:
                self.state = coordinator_pb2.State.FINISHED
            else:
                self.current_round += 1
                # reinitialize the round
                self.select_participant_ids_and_init_round()

        return coordinator_pb2.EndTrainingReply()
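A hedged participant-side sketch of how a request for the handler above might be built. It assumes only the `weights` and `number_samples` fields read by the handler and uses `ndarray_to_proto`, the counterpart of `proto_to_ndarray`; the helper name is made up, `coordinator_pb2` is the same generated module used throughout these examples, and the metrics map is omitted for brevity.

from typing import List

from numpy import ndarray
from numproto import ndarray_to_proto


def build_end_training_request(weights: List[ndarray], number_samples: int):
    # Serialize each weight array into the repeated proto field and attach
    # the number of training samples used for this round.
    return coordinator_pb2.EndTrainingRequest(
        weights=[ndarray_to_proto(w) for w in weights],
        number_samples=number_samples,
    )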
Example No. 2
    def SayHelloNumProto(self, request, context):
        nda = proto_to_ndarray(request.arr)
        logger.info("NumProto server received", nda=nda)

        nda *= 2
        logger.info("NumProto server sent", nda=nda)
        return hellonumproto_pb2.NumProtoReply(arr=ndarray_to_proto(nda))
Example No. 3
    def SayHelloNumProto(self, request, context):
        nda = proto_to_ndarray(request.arr)
        print("NumProto server received: {}".format(nda))

        nda *= 2
        print("NumProto server sent: {}".format(nda))
        return hellonumproto_pb2.NumProtoReply(arr=ndarray_to_proto(nda))
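A minimal sketch of serving the SayHelloNumProto handler above, assuming the generated module follows standard gRPC naming (NumProtoServerServicer and add_NumProtoServerServicer_to_server); the NumProtoServer class name and the port are only illustrative.

from concurrent import futures

import grpc


class NumProtoServer(hellonumproto_pb2_grpc.NumProtoServerServicer):
    def SayHelloNumProto(self, request, context):
        # Deserialize, double, and serialize back, as in the examples above.
        nda = proto_to_ndarray(request.arr)
        return hellonumproto_pb2.NumProtoReply(arr=ndarray_to_proto(nda * 2))


def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    hellonumproto_pb2_grpc.add_NumProtoServerServicer_to_server(NumProtoServer(), server)
    server.add_insecure_port("[::]:50051")
    server.start()
    server.wait_for_termination()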
Example No. 4
def to_polymath_arg(arg, graph, write_graph, pb_node, verbose):
    if arg.type == pb.Attribute.Type.NODE:
        arg_str = arg.s.decode("utf-8")
        if arg_str in graph.nodes:
            arg_node = graph.nodes[arg_str]
        elif write_graph and arg_str in write_graph.nodes:
            arg_node = write_graph.nodes[arg_str]
        else:
            if verbose:
                err_str = f"Could not find {arg_str} in nodes for {graph.name} - {graph}\n" \
                          f"Node name: {pb_node.name} - {pb_node.op_name}:\n" \
                          f"Keys: {list(graph.nodes.keys())}\n"
            else:
                err_str = f"Could not find {arg_str} in nodes for {graph.name} - {graph}\n" \
                          f"Node name: {pb_node.name} - {pb_node.op_name}:\n"
            raise RuntimeError(err_str)
        return arg_node
    elif arg.type == pb.Attribute.Type.NDARRAY:
        return proto_to_ndarray(arg.nda)
    elif arg.type == pb.Attribute.Type.INT32:
        return arg.i32
    elif arg.type == pb.Attribute.Type.DOUBLE:
        return arg.d
    elif arg.type == pb.Attribute.Type.STRING:
        return arg.s.decode("utf-8")
    elif arg.type == pb.Attribute.Type.BOOL:
        return arg.b
    else:
        raise RuntimeError(f"Unknown attribute type {arg.type} for node {pb_node.name}")
Example No. 5
def start_training(channel: Channel) -> Tuple[List[ndarray], int, int]:
    """Start a training initiation exchange with a coordinator.

    The decoded contents of the response from the coordinator are returned.

    Args:
        channel (~grpc.Channel): A gRPC channel to the coordinator.

    Returns:
        ~typing.List[~numpy.ndarray]: The weights of a global model to train on.
        int: The number of epochs to train.
        int: The epoch base of the global model.
    """

    coordinator: CoordinatorStub = CoordinatorStub(channel=channel)

    # send request to start training
    reply: StartTrainingReply = coordinator.StartTraining(
        request=StartTrainingRequest()
    )
    logger.info("Participant received reply", reply=type(reply))

    weights: List[ndarray] = [proto_to_ndarray(weight) for weight in reply.weights]
    epochs: int = reply.epochs
    epoch_base: int = reply.epoch_base

    return weights, epochs, epoch_base
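A short usage sketch for the function above; the coordinator address is an assumption chosen to match the other examples.

import grpc

with grpc.insecure_channel("localhost:50051") as channel:
    weights, epochs, epoch_base = start_training(channel)
    print(f"Received {len(weights)} weight arrays, {epochs} epochs, epoch base {epoch_base}")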
Example No. 6
def start_training(channel) -> Tuple[Theta, int, int]:
    stub = coordinator_pb2_grpc.CoordinatorStub(channel)
    req = coordinator_pb2.StartTrainingRequest()
    # send request to start training
    reply = stub.StartTraining(req)
    print(f"Participant received: {type(reply)}")
    theta, epochs, epoch_base = reply.theta, reply.epochs, reply.epoch_base
    return [proto_to_ndarray(pnda) for pnda in theta], epochs, epoch_base
Example No. 7
def test_greeter_server(greeter_server):
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = hellonumproto_pb2_grpc.NumProtoServerStub(channel)

        nda = np.arange(10)
        response = stub.SayHelloNumProto(
            hellonumproto_pb2.NumProtoRequest(arr=ndarray_to_proto(nda)))

        response_nda = proto_to_ndarray(response.arr)

        assert np.array_equal(nda * 2, response_nda)
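The greeter_server argument above is a pytest fixture that is not shown here; a hypothetical version might start the server in the background for the duration of the test (servicer class, registration function, and port are assumptions):

from concurrent import futures

import grpc
import pytest


@pytest.fixture
def greeter_server():
    # Start a throwaway gRPC server hosting the NumProto servicer,
    # hand control to the test, then shut the server down.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    hellonumproto_pb2_grpc.add_NumProtoServerServicer_to_server(NumProtoServer(), server)
    server.add_insecure_port("[::]:50051")
    server.start()
    yield server
    server.stop(None)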
Example No. 8
def run():
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = hellonumproto_pb2_grpc.NumProtoServerStub(channel)

        nda = np.arange(10)
        print("NumProto client sent: {}".format(nda))

        response = stub.SayHelloNumProto(
            hellonumproto_pb2.NumProtoRequest(arr=ndarray_to_proto(nda))
        )
    print("NumProto client received: {}".format(proto_to_ndarray(response.arr)))
Example No. 9
    def EndTraining(self, request, context):
        print(f"Received: {type(request)} from {context.peer()}")
        tu, his, met = request.theta_update, request.history, request.metrics
        tp, num = tu.theta_prime, tu.num_examples
        cid, vbc = met.cid, met.vol_by_class
        # record the req data
        theta_update = [proto_to_ndarray(pnda) for pnda in tp], num
        self.theta_updates.append(theta_update)
        self.histories.append({k: list(hv.values) for k, hv in his.items()})
        self.metricss.append((cid, list(vbc)))
        # reply
        return coordinator_pb2.EndTrainingReply()
Example No. 10
def test_start_training():
    test_weights = [np.arange(10), np.arange(10, 20)]
    coordinator = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=1.0,
        weights=test_weights,
    )
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    result = coordinator.on_message(coordinator_pb2.StartTrainingRequest(),
                                    "participant1")
    received_weights = [proto_to_ndarray(nda) for nda in result.weights]

    np.testing.assert_equal(test_weights, received_weights)
Example No. 11
def proto_to_pointer(proto):
    """
    Convert a serialized NDArray to a C pointer

    Parameters
    ----------
    proto : proto.NDArray

    Returns:
        pointer :  ctypes.POINTER(ctypes.u_int8)
    """
    ndarray = proto_to_ndarray(proto)
    # FIXME make the ctype POINTER type configurable
    pointer = ndarray.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
    return pointer
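A small round-trip sketch for the helper above. It assumes a recent NumPy in which ndarray.ctypes.data_as keeps the backing array alive, so the pointer returned by proto_to_pointer remains valid after the function returns.

import numpy as np
from numproto import ndarray_to_proto

nda = np.arange(5, dtype=np.uint8)
ptr = proto_to_pointer(ndarray_to_proto(nda))
# The pointer addresses the deserialized copy, so the bytes match the original array.
assert [ptr[i] for i in range(len(nda))] == list(nda)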
Example No. 12
def start_training(channel) -> Tuple[Theta, int, int]:
    """Starts a training initiation exchange with Coordinator. Returns the decoded
    contents of the response from Coordinator.

    Args:
        channel: gRPC channel to Coordinator.

    Returns:
        :obj:`Theta`: Global model to train on.
        :obj:`int`: Number of epochs.
        :obj:`int`: Epoch base.
    """
    stub = coordinator_pb2_grpc.CoordinatorStub(channel)
    req = coordinator_pb2.StartTrainingRequest()
    # send request to start training
    reply = stub.StartTraining(req)
    logger.info("Participant received", reply_type=type(reply))
    theta, epochs, epoch_base = reply.theta, reply.epochs, reply.epoch_base
    return [proto_to_ndarray(pnda) for pnda in theta], epochs, epoch_base
Example No. 13
def test_numproto(nda):
    """Tests serialization and deserialization of numpy arrays."""
    result = proto_to_ndarray(ndarray_to_proto(nda))
    assert np.array_equal(nda, result)
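The nda argument above is injected by the test harness; a hypothetical pytest parametrization exercising the same round-trip could look like this:

import numpy as np
import pytest
from numproto import ndarray_to_proto, proto_to_ndarray


@pytest.mark.parametrize(
    "nda",
    [np.arange(10), np.ones((2, 3), dtype=np.float32), np.array([])],
)
def test_numproto_roundtrip(nda):
    # Serialize to proto and back, then check the array survives unchanged.
    assert np.array_equal(nda, proto_to_ndarray(ndarray_to_proto(nda)))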
Example No. 14
File: core.py  Project: Trawely/mc2
def attest():
    """
    Verify remote attestation report of enclave and extract its public key.
    The report and public key are saved as instance attributes.
    Parameters for attestation, e.g. whether to verify report,
    whether to check client list, whether to check MRSIGNER/MRENCLAVE, can be specified in config YAML.

    """
    pem_key = ctypes.POINTER(ctypes.c_uint8)()
    pem_key_size = ctypes.c_size_t()
    nonce = ctypes.POINTER(ctypes.c_uint8)()
    nonce_size = ctypes.c_size_t()
    client_list = ctypes.POINTER(ctypes.c_char_p)()
    client_list_size = ctypes.c_size_t()
    remote_report = ctypes.POINTER(ctypes.c_uint8)()
    remote_report_size = ctypes.c_size_t()

    channel_addr = _CONF["remote_addr"]

    if channel_addr is None:
        raise MC2ClientConfigError(
            "Remote orchestrator IP not set. Run oc.create_cluster() to launch "
            "VMs and configure IPs automatically, or explicitly set it in the user YAML."
        )

    with grpc.insecure_channel(channel_addr) as channel:
        stub = remote_pb2_grpc.RemoteStub(channel)
        response = stub.rpc_get_remote_report_with_pubkey_and_nonce(
            remote_pb2.Status(status=1))

    pem_key = proto_to_ndarray(response.pem_key).ctypes.data_as(
        ctypes.POINTER(ctypes.c_uint8))
    pem_key_size = ctypes.c_size_t(response.pem_key_size)
    nonce = proto_to_ndarray(response.nonce).ctypes.data_as(
        ctypes.POINTER(ctypes.c_uint8))
    nonce_size = ctypes.c_size_t(response.nonce_size)
    client_list = from_pystr_to_cstr(list(response.client_list))
    client_list_size = ctypes.c_size_t(response.client_list_size)

    remote_report = proto_to_ndarray(response.remote_report).ctypes.data_as(
        ctypes.POINTER(ctypes.c_uint8))
    remote_report_size = ctypes.c_size_t(response.remote_report_size)

    if _CONF.get("general_config") is None:
        raise MC2ClientConfigError("Configuration not set")

    # Load config to see what parameters user has specified
    with open(_CONF["general_config"]) as config_file:
        attestation_config = yaml.safe_load(config_file)["attestation"]
    simulation_mode = attestation_config.get("simulation_mode")
    check_client_list = attestation_config.get("check_client_list")

    mrenclave_hash = attestation_config.get("mrenclave")
    if mrenclave_hash and mrenclave_hash != "NULL":
        check_mrenclave = 1
        expected_mrenclave = c_str(mrenclave_hash)
        # TODO: should this be incremented?
        expected_mrenclave_len = len(mrenclave_hash) + 1
    else:
        check_mrenclave = 0
        expected_mrenclave = c_str("NULL")
        expected_mrenclave_len = 0

    mrsigner_public_key = attestation_config.get("mrsigner")
    if mrsigner_public_key and mrsigner_public_key != "NULL":
        check_mrsigner = 1
        expected_mrsigner = c_str(mrsigner_public_key)
        expected_mrsigner_len = len(mrsigner_public_key) + 1
    else:
        check_mrsigner = 0
        expected_mrsigner = c_str("NULL")
        expected_mrsigner_len = 0

    verification_passes = ctypes.c_int()

    # Verify attestation report
    if not simulation_mode:
        # Check public key, nonce, client list is in report hash
        _LIB.attest(
            pem_key,
            pem_key_size,
            nonce,
            nonce_size,
            from_pystr_to_cstr(attestation_config.get("client_list")),
            ctypes.c_size_t(len(attestation_config.get("client_list"))),
            remote_report,
            remote_report_size,
            check_mrenclave,
            expected_mrenclave,
            ctypes.c_size_t(expected_mrenclave_len),
            check_mrsigner,
            expected_mrsigner,
            ctypes.c_size_t(expected_mrsigner_len),
            ctypes.byref(verification_passes),
        )

        if not verification_passes.value:
            raise AttestationError(
                "Remote attestation report verification failed")

    # Verify client names match
    if simulation_mode and check_client_list:
        received_client_list = sorted(
            from_cstr_to_pystr(client_list, client_list_size))
        expected_client_list = sorted(attestation_config.get("client_list"))
        if received_client_list != expected_client_list:
            raise AttestationError(
                "Provided client list doesn't match that received from enclave"
            )

    # Set nonce, enclave public key, respective sizes
    _CONF["enclave_pk"] = pem_key
    _CONF["enclave_pk_size"] = pem_key_size
    _CONF["nonce"] = nonce
    _CONF["nonce_size"] = nonce_size
    _CONF["nonce_ctr"] = 0

    # Add client key to enclave
    # TODO: figure out how to do this for both Secure XGBoost and Opaque
    _add_client_key()
    _get_enclave_symm_key()
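For reference, a hypothetical minimal attestation section covering exactly the keys attest() reads above; the values are illustrative only, and "NULL" disables the corresponding measurement check.

import yaml

example_config = {
    "attestation": {
        "simulation_mode": 1,
        "check_client_list": 1,
        "mrenclave": "NULL",
        "mrsigner": "NULL",
        "client_list": ["user1"],
    }
}

# Write this section into the YAML file referenced by _CONF["general_config"].
print(yaml.safe_dump(example_config))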
Example No. 15
def _deserialize_node(pb_node,
                      deserialization_info,
                      graph=None,
                      verbose=False):
    set_fields = pb_node.DESCRIPTOR.fields_by_name
    kwargs = {}
    kwargs["name"] = pb_node.name
    kwargs["op_name"] = pb_node.op_name
    kwargs["dependencies"] = [dep for dep in pb_node.dependencies]
    write_graph = None

    if kwargs["op_name"] == "write":
        wg = pb_node.kwargs["write_graph"]
        kwargs["write_graph"] = [a.decode("utf-8") for a in wg.ss]
        curr_g = graph
        while curr_g:
            if len(kwargs["write_graph"]
                   ) > 0 and kwargs["write_graph"][-1] in curr_g.nodes:
                write_graph = curr_g.nodes[kwargs["write_graph"][-1]]
                break
            curr_g = curr_g.graph

    args = []
    for name in pb_node.kwargs:
        arg = pb_node.kwargs[name]
        if arg.type == pb.Attribute.Type.DOM:
            kwargs[name] = _deserialize_domain(arg,
                                               graph,
                                               pb_node.name,
                                               deserialization_info,
                                               write_graph=write_graph)
        elif arg.type == pb.Attribute.Type.NODE:
            # Node arguments are stored by name in the string field `s`.
            arg_name = arg.s.decode("utf-8")
            if arg_name in graph.nodes:
                arg_node = graph.nodes[arg_name]
            elif write_graph and arg_name in write_graph.nodes:
                arg_node = write_graph.nodes[arg_name]
            else:
                raise KeyError(
                    f"Unable to find node in graph {arg_name}")
            kwargs[name] = arg_node
        elif arg.type == pb.Attribute.Type.NDARRAY:
            kwargs[name] = proto_to_ndarray(arg.nda)
        elif arg.type == pb.Attribute.Type.INT32:
            kwargs[name] = arg.i32
        elif arg.type == pb.Attribute.Type.DOUBLE:
            kwargs[name] = arg.d
        elif arg.type == pb.Attribute.Type.STRING:
            kwargs[name] = arg.s.decode("utf-8")
        elif arg.type == pb.Attribute.Type.BOOL:
            kwargs[name] = arg.b
        elif arg.type == pb.Attribute.Type.NODES:
            arg_node = []
            for a in arg.ss:
                if a.decode("utf-8") in graph.nodes:
                    anode = graph.nodes[a.decode("utf-8")]
                elif write_graph and a.decode("utf-8") in write_graph.nodes:
                    anode = write_graph.nodes[a.decode("utf-8")]
                else:
                    raise KeyError(
                        f"Unable to find node in graph {a.decode('utf-8')}")
                arg_node.append(anode)
            kwargs[name] = arg_node
        elif arg.type == pb.Attribute.Type.NDARRAYS:
            kwargs[name] = [proto_to_ndarray(a) for a in arg.ndas]
        elif arg.type == pb.Attribute.Type.INT32S:
            kwargs[name] = list(arg.i32s)
        elif arg.type == pb.Attribute.Type.DOUBLES:
            kwargs[name] = list(arg.ds)
        elif arg.type == pb.Attribute.Type.STRINGS:
            kwargs[name] = [a.decode("utf-8") for a in arg.ss]
        elif arg.type == pb.Attribute.Type.BOOLS:
            kwargs[name] = list(arg.b)
        else:
            raise TypeError(
                f"Cannot find deserializable method for argument {name} with type {arg.type}"
            )

    for i, arg in enumerate(pb_node.args):
        if arg.type == pb.Attribute.Type.NODE:
            arg_str = arg.s.decode("utf-8")
            if arg_str in graph.nodes:
                arg_node = graph.nodes[arg_str]
            elif write_graph and arg_str in write_graph.nodes:
                arg_node = write_graph.nodes[arg_str]
            else:
                if verbose:
                    err_str = f"Could not find {arg_str} in nodes for {graph.name} - {graph}\n" \
                              f"Node name: {pb_node.name} - {pb_node.op_name}:\n" \
                                   f"Keys: {list(graph.nodes.keys())}\n"
                else:
                    err_str = f"Could not find {arg_str} in nodes for {graph.name} - {graph}\n" \
                              f"Node name: {pb_node.name} - {pb_node.op_name}:\n"
                raise RuntimeError(err_str)
            args.append(arg_node)
        elif arg.type == pb.Attribute.Type.NDARRAY:
            args.append(proto_to_ndarray(arg.nda))
        elif arg.type == pb.Attribute.Type.INT32:
            args.append(arg.i32)
        elif arg.type == pb.Attribute.Type.DOUBLE:
            args.append(arg.d)
        elif arg.type == pb.Attribute.Type.STRING:
            args.append(arg.s.decode("utf-8"))
        elif arg.type == pb.Attribute.Type.BOOL:
            args.append(arg.b)
        elif arg.type == pb.Attribute.Type.MAP:
            mapping = {}
            for name in arg.mapping:
                mapped_arg = arg.mapping[name]
                mapping[name] = to_polymath_arg(mapped_arg,
                                                graph,
                                                write_graph,
                                                pb_node,
                                                verbose=verbose)
            args.append(mapping)
        elif arg.type == pb.Attribute.Type.NODES:
            arg_node = []
            for a in arg.ss:
                if a.decode("utf-8") in graph.nodes:
                    anode = graph.nodes[a.decode("utf-8")]
                elif write_graph and a.decode("utf-8") in write_graph.nodes:
                    anode = write_graph.nodes[a.decode("utf-8")]
                else:
                    raise KeyError(
                        f"Unable to find node in graph {a.decode('utf-8')}")
                arg_node.append(anode)
            args.append(arg_node)
        elif arg.type == pb.Attribute.Type.NDARRAYS:
            args.append([proto_to_ndarray(a) for a in arg.ndas])
        elif arg.type == pb.Attribute.Type.INT32S:
            args.append(list(arg.i32s))
        elif arg.type == pb.Attribute.Type.DOUBLES:
            args.append(list(arg.ds))
        elif arg.type == pb.Attribute.Type.STRINGS:
            args.append([a.decode("utf-8") for a in arg.ss])
        elif arg.type == pb.Attribute.Type.BOOLS:
            args.append(list(arg.b))
        else:
            raise TypeError(
                f"Cannot find deserializable method for argument {arg} with type {arg.type}"
            )

    args = tuple(args)

    mod_name, cls_name = pb_node.module.rsplit(".", 1)
    mod = __import__(mod_name, fromlist=[cls_name])

    if "target" in kwargs:
        func_mod_name, func_name = kwargs["target"].rsplit(".", 1)
        func_mod = __import__(func_mod_name, fromlist=[func_name])
        target = getattr(func_mod, func_name)
        kwargs.pop("target")

        if cls_name in ["func_op", "slice_op", "index_op"]:
            node = getattr(mod, cls_name)(target, *args, graph=graph, **kwargs)
        else:
            node = getattr(mod, cls_name)(*args, graph=graph, **kwargs)
    else:
        template_subclass_names = [
            c.__name__ for c in Template.__subclasses__()
        ]

        if cls_name in template_subclass_names:
            kwargs.pop("op_name")
            kwargs['skip_definition'] = True
            for a in args:
                if isinstance(
                        a, CONTEXT_TEMPLATE_TYPES
                ) and a.name not in deserialization_info['write_resets']:
                    a.write_count = 0
        node = getattr(mod, cls_name)(*args, graph=graph, **kwargs)
    graph_known = pb_node.graph_id in deserialization_info['uuid_map']
    if pb_node.graph_id >= 0 and not graph_known:
        print(f"Cannot find {pb_node.name} with graph {node.graph.name}")
    deserialization_info['uuid_map'][pb_node.uuid] = node

    for pb_n in pb_node.nodes:

        if pb_n.name in node.nodes:
            continue
        node.nodes[pb_n.name] = _deserialize_node(pb_n,
                                                  deserialization_info,
                                                  graph=node,
                                                  verbose=verbose)

    shape_list = []
    for shape in pb_node.shape:
        val_type = shape.WhichOneof("value")
        if val_type == "shape_const":
            shape_list.append(shape.shape_const)
        else:
            if shape.shape_id not in graph.nodes:
                shape_list.append(node.nodes[shape.shape_id])
            else:
                shape_list.append(graph.nodes[shape.shape_id])
    node._shape = tuple(shape_list)

    return node
Example No. 16
def _deserialize_domain(pb_dom, graph, node_name, info, write_graph=None):
    doms = []

    for d in pb_dom.dom.domains:
        if d.type == pb.Attribute.Type.NODE:
            if d.s.decode("utf-8") != node_name:
                d_name = d.s.decode("utf-8")
                if d_name in graph.nodes:
                    arg_node = graph.nodes[d_name]
                elif write_graph and d_name in write_graph.nodes:
                    arg_node = write_graph.nodes[d_name]
                else:
                    all_graphs = []
                    g = graph

                    while g is not None:
                        all_graphs.append(g.name)
                        g = g.graph
                    arg_node = None
                    for k, v in info['uuid_map'].items():
                        if d_name == v.name:
                            arg_node = v
                            break
                    if arg_node is None:
                        raise KeyError(
                            f"Unable to find node in graph {d_name} for {node_name}: "
                            f"{graph.name}. All Graphs: {all_graphs}\n"
                            f"Write graph: {write_graph}")
                doms.append(arg_node)

        elif d.type == pb.Attribute.Type.NDARRAY:
            doms.append(proto_to_ndarray(d.nda))
        elif d.type == pb.Attribute.Type.INT32:
            doms.append(d.i32)
        elif d.type == pb.Attribute.Type.DOUBLE:
            doms.append(d.d)
        elif d.type == pb.Attribute.Type.STRING:
            doms.append(d.s.decode("utf-8"))
        elif d.type == pb.Attribute.Type.BOOL:
            doms.append(d.b)
        elif d.type == pb.Attribute.Type.NODES:
            arg_node = []
            for a in d.ss:
                if a.decode("utf-8") in graph.nodes:
                    anode = graph.nodes[a.decode("utf-8")]
                elif write_graph and a.decode("utf-8") in write_graph.nodes:
                    anode = write_graph.nodes[a.decode("utf-8")]
                else:
                    raise KeyError(
                        f"Unable to find node in graph {a.decode('utf-8')}")
                arg_node.append(anode)
            doms.append(arg_node)
        elif d.type == pb.Attribute.Type.NDARRAYS:
            doms.append([proto_to_ndarray(a) for a in d.ndas])
        elif d.type == pb.Attribute.Type.INT32S:
            doms.append(list(d.i32s))
        elif d.type == pb.Attribute.Type.DOUBLES:
            doms.append(list(d.ds))
        elif d.type == pb.Attribute.Type.STRINGS:
            doms.append([a.decode("utf-8") for a in d.ss])
        elif d.type == pb.Attribute.Type.BOOLS:
            doms.append(list(d.b))
        else:
            raise TypeError(
                f"Cannot find deserializable method for argument {d} with type {d.type}"
            )
    return Domain(tuple(doms))