def _handle_end_training(
    self, message: coordinator_pb2.EndTrainingRequest, participant_id: str
) -> coordinator_pb2.EndTrainingReply:
    """Handle an EndTraining request.

    Records the participant's weight update and metrics for the current
    round. If the round then reports itself finished, the collected weight
    updates are aggregated into ``self.weights`` and the coordinator either
    advances to the next round (re-selecting participants) or, when the
    current round equals ``self.num_rounds``, switches to the ``FINISHED``
    state.

    Args:
        message (:class:`~.coordinator_pb2.EndTrainingRequest`): The request to handle.
        participant_id (:obj:`str`): The id of the participant making the request.

    Returns:
        :class:`~.coordinator_pb2.EndTrainingReply`: The reply to the participant.
    """
    # TODO: Ideally we want to know for which round the participant is
    # submitting the updates and raise an exception if it is the wrong
    # round.
    decoded_weights: List[ndarray] = [proto_to_ndarray(w) for w in message.weights]
    weight_update: Tuple[List[ndarray], int] = (decoded_weights, message.number_samples)

    metrics: Dict[str, List[ndarray]] = {}
    for metric_name, metric_values in message.metrics.items():
        metrics[metric_name] = [proto_to_ndarray(value) for value in metric_values.metrics]

    self.round.add_updates(participant_id, weight_update, metrics)

    # Nothing else to do until the round reports that it is finished.
    if not self.round.is_finished():
        return coordinator_pb2.EndTrainingReply()

    logger.info("Running aggregation for round", current_round=self.current_round)
    self.weights = self.aggregator.aggregate(self.round.get_weight_updates())

    if self.current_round == self.num_rounds:
        # That was the last round: the training session is over.
        self.state = coordinator_pb2.State.FINISHED
    else:
        self.current_round += 1
        # reinitialize the round
        self.select_participant_ids_and_init_round()
    return coordinator_pb2.EndTrainingReply()
def SayHelloNumProto(self, request, context):
    """Reply with the received array multiplied by two.

    Args:
        request: Message whose ``arr`` field holds the serialized array.
        context: The gRPC servicer context (not used).

    Returns:
        A ``hellonumproto_pb2.NumProtoReply`` carrying the doubled array.
    """
    received = proto_to_ndarray(request.arr)
    logger.info("NumProto server received", nda=received)
    # Doubled in place; the same array object is logged and serialized below.
    received *= 2
    logger.info("NumProto server sent", nda=received)
    reply_arr = ndarray_to_proto(received)
    return hellonumproto_pb2.NumProtoReply(arr=reply_arr)
def SayHelloNumProto(self, request, context):
    """Reply with the received array multiplied by two, printing both ends.

    Args:
        request: Message whose ``arr`` field holds the serialized array.
        context: The gRPC servicer context (not used).

    Returns:
        A ``hellonumproto_pb2.NumProtoReply`` carrying the doubled array.
    """
    received = proto_to_ndarray(request.arr)
    print("NumProto server received: {}".format(received))
    # Doubled in place; the same array object is printed and serialized below.
    received *= 2
    print("NumProto server sent: {}".format(received))
    reply_arr = ndarray_to_proto(received)
    return hellonumproto_pb2.NumProtoReply(arr=reply_arr)
def to_polymath_arg(arg, graph, write_graph, pb_node, verbose):
    """Convert a single serialized ``pb.Attribute`` into its Python value.

    Node references (stored by name in ``arg.s``) are resolved against
    ``graph.nodes`` first and then ``write_graph.nodes``; scalar and array
    attributes are decoded directly.

    Args:
        arg: The serialized attribute to convert.
        graph: Graph whose ``nodes`` mapping is searched first for node references.
        write_graph: Optional secondary node container, searched when ``graph``
            does not contain the referenced name.
        pb_node: The serialized node that owns ``arg``; only used in error messages.
        verbose: If true, "not found" errors also list the keys of ``graph.nodes``.

    Returns:
        The resolved node, the decoded array, or the int / float / str / bool value.

    Raises:
        RuntimeError: If a referenced node cannot be found, or if ``arg.type``
            is not one of the supported attribute types.
    """
    if arg.type == pb.Attribute.Type.NODE:
        arg_str = arg.s.decode("utf-8")
        if arg_str in graph.nodes:
            return graph.nodes[arg_str]
        if write_graph and arg_str in write_graph.nodes:
            return write_graph.nodes[arg_str]
        err_str = f"Could not find {arg_str} in nodes for {graph.name} - {graph}\n" \
                  f"Node name: {pb_node.name} - {pb_node.op_name}:\n"
        if verbose:
            err_str += f"Keys: {list(graph.nodes.keys())}\n"
        raise RuntimeError(err_str)
    elif arg.type == pb.Attribute.Type.NDARRAY:
        return proto_to_ndarray(arg.nda)
    elif arg.type == pb.Attribute.Type.INT32:
        return arg.i32
    elif arg.type == pb.Attribute.Type.DOUBLE:
        return arg.d
    elif arg.type == pb.Attribute.Type.STRING:
        return arg.s.decode("utf-8")
    elif arg.type == pb.Attribute.Type.BOOL:
        return arg.b
    # Previously a bare ``raise RuntimeError`` with no message; keep the same
    # exception type but say what could not be converted.
    raise RuntimeError(
        f"Cannot convert argument with type {arg.type} "
        f"for node {pb_node.name} - {pb_node.op_name}"
    )
def start_training(channel: Channel) -> Tuple[List[ndarray], int, int]:
    """Ask the coordinator to start training and decode its reply.

    Args:
        channel (~grpc.Channel): A gRPC channel connected to the coordinator.

    Returns:
        ~typing.List[~numpy.ndarray]: The weights of the global model to train on.
        int: The number of epochs to train.
        int: The epoch base of the global model.
    """
    # send request to start training
    reply: StartTrainingReply = CoordinatorStub(channel=channel).StartTraining(
        request=StartTrainingRequest()
    )
    logger.info("Participant received reply", reply=type(reply))
    return (
        [proto_to_ndarray(weight_proto) for weight_proto in reply.weights],
        reply.epochs,
        reply.epoch_base,
    )
def start_training(channel) -> Tuple[Theta, int, int]:
    """Start a training initiation exchange with the coordinator.

    Args:
        channel: gRPC channel to the coordinator.

    Returns:
        A ``(theta, epochs, epoch_base)`` tuple, where ``theta`` is the list of
        arrays decoded from the reply's ``theta`` field.
    """
    coordinator_stub = coordinator_pb2_grpc.CoordinatorStub(channel)
    # send request to start training
    reply = coordinator_stub.StartTraining(coordinator_pb2.StartTrainingRequest())
    print(f"Participant received: {type(reply)}")
    global_theta = [proto_to_ndarray(pnda) for pnda in reply.theta]
    return global_theta, reply.epochs, reply.epoch_base
def test_greeter_server(greeter_server):
    """The NumProto server should answer with the sent array doubled.

    ``greeter_server`` is a pytest fixture; presumably it starts the server on
    localhost:50051 (TODO confirm against the fixture definition).
    """
    sent = np.arange(10)
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = hellonumproto_pb2_grpc.NumProtoServerStub(channel)
        request = hellonumproto_pb2.NumProtoRequest(arr=ndarray_to_proto(sent))
        response = stub.SayHelloNumProto(request)
        received = proto_to_ndarray(response.arr)
        assert np.array_equal(sent * 2, received)
def run():
    """Send ``np.arange(10)`` to the NumProto server and print both sides of the exchange."""
    payload = np.arange(10)
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = hellonumproto_pb2_grpc.NumProtoServerStub(channel)
        print("NumProto client sent: {}".format(payload))
        request = hellonumproto_pb2.NumProtoRequest(arr=ndarray_to_proto(payload))
        response = stub.SayHelloNumProto(request)
        print("NumProto client received: {}".format(proto_to_ndarray(response.arr)))
def EndTraining(self, request, context):
    """Record a participant's end-of-training submission and send an empty reply.

    Appends the decoded ``(theta_prime, num_examples)`` update to
    ``self.theta_updates``, the history values to ``self.histories`` and the
    ``(cid, vol_by_class)`` metrics to ``self.metricss``.
    """
    print(f"Received: {type(request)} from {context.peer()}")
    update = request.theta_update
    # record the req data
    decoded_theta = [proto_to_ndarray(pnda) for pnda in update.theta_prime]
    self.theta_updates.append((decoded_theta, update.num_examples))
    self.histories.append(
        {key: list(history.values) for key, history in request.history.items()}
    )
    participant_metrics = request.metrics
    self.metricss.append(
        (participant_metrics.cid, list(participant_metrics.vol_by_class))
    )
    # reply
    return coordinator_pb2.EndTrainingReply()
def test_start_training():
    """StartTraining should hand back the coordinator's initial weights unchanged."""
    test_weights = [np.arange(10), np.arange(10, 20)]
    coordinator = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=1.0,
        weights=test_weights,
    )
    participant = "participant1"
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), participant)
    reply = coordinator.on_message(coordinator_pb2.StartTrainingRequest(), participant)
    np.testing.assert_equal(
        test_weights, [proto_to_ndarray(weight) for weight in reply.weights]
    )
def proto_to_pointer(proto, ctype=None):
    """
    Convert a serialized NDArray to a C pointer

    Parameters
    ----------
    proto : proto.NDArray
        The serialized array to decode.
    ctype : ctypes simple type, optional
        Element type of the returned pointer. Defaults to ``ctypes.c_uint8``,
        which matches the previous hard-coded behaviour.

    Returns
    -------
    pointer : ctypes.POINTER(ctype)
        Pointer to the data buffer of the decoded array.
    """
    pointer_type = ctypes.c_uint8 if ctype is None else ctype
    ndarray = proto_to_ndarray(proto)
    # NumPy documents that the pointer returned by ``data_as`` keeps a
    # reference to the array, so the buffer is not freed when ``ndarray``
    # goes out of scope here.
    return ndarray.ctypes.data_as(ctypes.POINTER(pointer_type))
def start_training(channel) -> Tuple[Theta, int, int]:
    """Start a training initiation exchange with the Coordinator.

    Returns the decoded contents of the Coordinator's response.

    Args:
        channel: gRPC channel to the Coordinator.

    Returns:
        :obj:`Theta`: Global model to train on.
        :obj:`int`: Number of epochs.
        :obj:`int`: Epoch base.
    """
    coordinator_stub = coordinator_pb2_grpc.CoordinatorStub(channel)
    # send request to start training
    reply = coordinator_stub.StartTraining(coordinator_pb2.StartTrainingRequest())
    logger.info("Participant received", reply_type=type(reply))
    global_theta = [proto_to_ndarray(pnda) for pnda in reply.theta]
    return global_theta, reply.epochs, reply.epoch_base
def test_numproto(nda):
    """A numpy array must survive a round trip through its protobuf encoding."""
    encoded = ndarray_to_proto(nda)
    decoded = proto_to_ndarray(encoded)
    assert np.array_equal(nda, decoded)
def _attestation_measurement(expected_value):
    """Translate an optional MRENCLAVE/MRSIGNER config value into ``_LIB.attest`` arguments.

    Args:
        expected_value: The value read from the attestation config (may be ``None``).

    Returns:
        A ``(check_flag, c_string, length)`` triple. If ``expected_value`` is
        falsy or the literal string ``"NULL"``, the check is disabled:
        ``(0, c_str("NULL"), 0)``. Otherwise
        ``(1, c_str(expected_value), len(expected_value) + 1)``.
    """
    if expected_value and expected_value != "NULL":
        # TODO: should this be incremented?
        return 1, c_str(expected_value), len(expected_value) + 1
    return 0, c_str("NULL"), 0


def attest():
    """
    Verify remote attestation report of enclave and extract its public key.

    The enclave public key, nonce and their sizes are stored in the global
    ``_CONF`` (``enclave_pk``, ``enclave_pk_size``, ``nonce``, ``nonce_size``)
    and ``nonce_ctr`` is reset to 0; afterwards ``_add_client_key()`` and
    ``_get_enclave_symm_key()`` are called.

    Parameters for attestation, e.g. whether to verify report, whether to check
    client list, whether to check MRSIGNER/MRENCLAVE, can be specified in the
    ``attestation`` section of the config YAML.

    Raises:
        MC2ClientConfigError: If the remote address or the general config path is not set.
        AttestationError: If report verification fails, or (in simulation mode
            with ``check_client_list`` enabled) the enclave's client list does
            not match the configured one.
    """
    channel_addr = _CONF["remote_addr"]
    if channel_addr is None:
        # Implicit concatenation instead of a backslash continuation inside the
        # literal, which embedded the next line's indentation in the message.
        raise MC2ClientConfigError(
            "Remote orchestrator IP not set. Run oc.create_cluster() "
            "to launch VMs and configure IPs automatically or explicitly set it in the user YAML."
        )

    # Fetch the attestation report together with the enclave's public key,
    # nonce and client list, converted to ctypes values for _LIB.attest.
    with grpc.insecure_channel(channel_addr) as channel:
        stub = remote_pb2_grpc.RemoteStub(channel)
        response = stub.rpc_get_remote_report_with_pubkey_and_nonce(
            remote_pb2.Status(status=1))

        pem_key = proto_to_ndarray(response.pem_key).ctypes.data_as(
            ctypes.POINTER(ctypes.c_uint8))
        pem_key_size = ctypes.c_size_t(response.pem_key_size)
        nonce = proto_to_ndarray(response.nonce).ctypes.data_as(
            ctypes.POINTER(ctypes.c_uint8))
        nonce_size = ctypes.c_size_t(response.nonce_size)
        client_list = from_pystr_to_cstr(list(response.client_list))
        client_list_size = ctypes.c_size_t(response.client_list_size)
        remote_report = proto_to_ndarray(response.remote_report).ctypes.data_as(
            ctypes.POINTER(ctypes.c_uint8))
        remote_report_size = ctypes.c_size_t(response.remote_report_size)

    if _CONF.get("general_config") is None:
        raise MC2ClientConfigError("Configuration not set")

    # Load config to see what parameters user has specified; the context
    # manager makes sure the file handle is closed.
    with open(_CONF["general_config"]) as config_file:
        attestation_config = yaml.safe_load(config_file.read())["attestation"]
    simulation_mode = attestation_config.get("simulation_mode")
    check_client_list = attestation_config.get("check_client_list")

    check_mrenclave, expected_mrenclave, expected_mrenclave_len = \
        _attestation_measurement(attestation_config.get("mrenclave"))
    check_mrsigner, expected_mrsigner, expected_mrsigner_len = \
        _attestation_measurement(attestation_config.get("mrsigner"))

    verification_passes = ctypes.c_int()

    # Verify attestation report
    if not simulation_mode:
        # Check public key, nonce, client list is in report hash
        _LIB.attest(
            pem_key,
            pem_key_size,
            nonce,
            nonce_size,
            from_pystr_to_cstr(attestation_config.get("client_list")),
            ctypes.c_size_t(len(attestation_config.get("client_list"))),
            remote_report,
            remote_report_size,
            check_mrenclave,
            expected_mrenclave,
            ctypes.c_size_t(expected_mrenclave_len),
            check_mrsigner,
            expected_mrsigner,
            ctypes.c_size_t(expected_mrsigner_len),
            ctypes.byref(verification_passes),
        )
        if not verification_passes.value:
            raise AttestationError(
                "Remote attestation report verification failed")

    # Verify client names match
    if simulation_mode and check_client_list:
        received_client_list = sorted(
            from_cstr_to_pystr(client_list, client_list_size))
        expected_client_list = sorted(attestation_config.get("client_list"))
        if received_client_list != expected_client_list:
            raise AttestationError(
                "Provided client list doesn't match that received from enclave"
            )

    # Set nonce, enclave public key, respective sizes
    _CONF["enclave_pk"] = pem_key
    _CONF["enclave_pk_size"] = pem_key_size
    _CONF["nonce"] = nonce
    _CONF["nonce_size"] = nonce_size
    _CONF["nonce_ctr"] = 0

    # Add client key to enclave
    # TODO: figure out how to do this for both Secure XGBoost and Opaque
    _add_client_key()
    _get_enclave_symm_key()
def _find_node_in_graphs(node_name, graph, write_graph):
    """Return ``graph.nodes[node_name]``, falling back to ``write_graph.nodes[node_name]``.

    Raises:
        KeyError: If neither ``graph`` nor ``write_graph`` contains ``node_name``.
    """
    if node_name in graph.nodes:
        return graph.nodes[node_name]
    if write_graph and node_name in write_graph.nodes:
        return write_graph.nodes[node_name]
    raise KeyError(f"Unable to find node in graph {node_name}")


def _deserialize_node(pb_node, deserialization_info, graph=None, verbose=False):
    """Rebuild a node, and recursively its child nodes, from its protobuf form.

    Args:
        pb_node: The serialized node.
        deserialization_info: Shared deserialization state. This function reads
            ``'write_resets'`` and ``'uuid_map'`` and registers the new node in
            ``'uuid_map'`` under ``pb_node.uuid``.
        graph: The already-deserialized enclosing graph; node references are
            resolved by name against its ``nodes``.
        verbose: If true, "node not found" errors for positional NODE args also
            list the keys of ``graph.nodes``.

    Returns:
        The reconstructed node, built by ``pb_node.module`` (``"<module>.<class>"``).

    Raises:
        KeyError: If a node referenced by a keyword or NODES argument is missing.
        RuntimeError: If a node referenced by a positional NODE argument is missing.
        TypeError: If an argument has an unsupported attribute type.
    """
    kwargs = {}
    kwargs["name"] = pb_node.name
    kwargs["op_name"] = pb_node.op_name
    kwargs["dependencies"] = [dep for dep in pb_node.dependencies]

    # For "write" nodes, locate the node named by the last entry of the
    # ``write_graph`` kwarg by walking up the chain of enclosing graphs.
    write_graph = None
    if kwargs["op_name"] == "write":
        wg = pb_node.kwargs["write_graph"]
        kwargs["write_graph"] = [a.decode("utf-8") for a in wg.ss]
        curr_g = graph
        while curr_g:
            if len(kwargs["write_graph"]) > 0 and kwargs["write_graph"][-1] in curr_g.nodes:
                write_graph = curr_g.nodes[kwargs["write_graph"][-1]]
                break
            curr_g = curr_g.graph

    # Keyword arguments.
    for name in pb_node.kwargs:
        arg = pb_node.kwargs[name]
        if arg.type == pb.Attribute.Type.DOM:
            kwargs[name] = _deserialize_domain(arg, graph, pb_node.name, deserialization_info,
                                               write_graph=write_graph)
        elif arg.type == pb.Attribute.Type.NODE:
            # Node references are stored by name in the ``s`` field; the
            # Attribute message itself has no ``decode`` method.
            kwargs[name] = _find_node_in_graphs(arg.s.decode("utf-8"), graph, write_graph)
        elif arg.type == pb.Attribute.Type.NDARRAY:
            kwargs[name] = proto_to_ndarray(arg.nda)
        elif arg.type == pb.Attribute.Type.INT32:
            kwargs[name] = arg.i32
        elif arg.type == pb.Attribute.Type.DOUBLE:
            kwargs[name] = arg.d
        elif arg.type == pb.Attribute.Type.STRING:
            kwargs[name] = arg.s.decode("utf-8")
        elif arg.type == pb.Attribute.Type.BOOL:
            kwargs[name] = arg.b
        elif arg.type == pb.Attribute.Type.NODES:
            kwargs[name] = [_find_node_in_graphs(a.decode("utf-8"), graph, write_graph)
                            for a in arg.ss]
        elif arg.type == pb.Attribute.Type.NDARRAYS:
            kwargs[name] = [proto_to_ndarray(a) for a in arg.ndas]
        elif arg.type == pb.Attribute.Type.INT32S:
            kwargs[name] = list(arg.i32s)
        elif arg.type == pb.Attribute.Type.DOUBLES:
            kwargs[name] = list(arg.ds)
        elif arg.type == pb.Attribute.Type.STRINGS:
            kwargs[name] = [a.decode("utf-8") for a in arg.ss]
        elif arg.type == pb.Attribute.Type.BOOLS:
            # NOTE(review): ``b`` is also the scalar BOOL field; confirm the
            # repeated-bool field name in the schema.
            kwargs[name] = list(arg.b)
        else:
            raise TypeError(
                f"Cannot find deserializeable method for argument {name} with type {arg.type}"
            )

    # Positional arguments.
    args = []
    for arg in pb_node.args:
        if arg.type == pb.Attribute.Type.NODE:
            arg_str = arg.s.decode("utf-8")
            if arg_str in graph.nodes:
                arg_node = graph.nodes[arg_str]
            elif write_graph and arg_str in write_graph.nodes:
                arg_node = write_graph.nodes[arg_str]
            else:
                err_str = f"Could not find {arg_str} in nodes for {graph.name} - {graph}\n" \
                          f"Node name: {pb_node.name} - {pb_node.op_name}:\n"
                if verbose:
                    err_str += f"Keys: {list(graph.nodes.keys())}\n"
                raise RuntimeError(err_str)
            args.append(arg_node)
        elif arg.type == pb.Attribute.Type.NDARRAY:
            args.append(proto_to_ndarray(arg.nda))
        elif arg.type == pb.Attribute.Type.INT32:
            args.append(arg.i32)
        elif arg.type == pb.Attribute.Type.DOUBLE:
            args.append(arg.d)
        elif arg.type == pb.Attribute.Type.STRING:
            args.append(arg.s.decode("utf-8"))
        elif arg.type == pb.Attribute.Type.BOOL:
            args.append(arg.b)
        elif arg.type == pb.Attribute.Type.MAP:
            mapping = {}
            for map_key in arg.mapping:
                mapping[map_key] = to_polymath_arg(arg.mapping[map_key], graph, write_graph,
                                                   pb_node, verbose=verbose)
            args.append(mapping)
        elif arg.type == pb.Attribute.Type.NODES:
            args.append([_find_node_in_graphs(a.decode("utf-8"), graph, write_graph)
                         for a in arg.ss])
        elif arg.type == pb.Attribute.Type.NDARRAYS:
            args.append([proto_to_ndarray(a) for a in arg.ndas])
        elif arg.type == pb.Attribute.Type.INT32S:
            args.append(list(arg.i32s))
        elif arg.type == pb.Attribute.Type.DOUBLES:
            args.append(list(arg.ds))
        elif arg.type == pb.Attribute.Type.STRINGS:
            args.append([a.decode("utf-8") for a in arg.ss])
        elif arg.type == pb.Attribute.Type.BOOLS:
            # NOTE(review): see the BOOLS note in the keyword loop above.
            args.append(list(arg.b))
        else:
            raise TypeError(
                f"Cannot find deserializeable method for argument {arg} with type {arg.type}"
            )
    args = tuple(args)

    # Instantiate the node class named by ``pb_node.module``.
    mod_name, cls_name = pb_node.module.rsplit(".", 1)
    mod = __import__(mod_name, fromlist=[cls_name])
    if "target" in kwargs:
        func_mod_name, func_name = kwargs["target"].rsplit(".", 1)
        func_mod = __import__(func_mod_name, fromlist=[func_name])
        target = getattr(func_mod, func_name)
        kwargs.pop("target")
        if cls_name in ["func_op", "slice_op", "index_op"]:
            node = getattr(mod, cls_name)(target, *args, graph=graph, **kwargs)
        else:
            node = getattr(mod, cls_name)(*args, graph=graph, **kwargs)
    else:
        template_subclass_names = [c.__name__ for c in Template.__subclasses__()]
        if cls_name in template_subclass_names:
            kwargs.pop("op_name")
            kwargs['skip_definition'] = True
            for a in args:
                if isinstance(a, CONTEXT_TEMPLATE_TYPES) \
                        and a.name not in deserialization_info['write_resets']:
                    a.write_count = 0
        node = getattr(mod, cls_name)(*args, graph=graph, **kwargs)

    if pb_node.graph_id >= 0 and pb_node.graph_id not in deserialization_info['uuid_map']:
        print(f"Cannot find {pb_node.name} with graph {node.graph.name}")
    deserialization_info['uuid_map'][pb_node.uuid] = node

    # Recurse into child nodes that the constructor did not already create.
    for pb_n in pb_node.nodes:
        if pb_n.name in node.nodes:
            continue
        node.nodes[pb_n.name] = _deserialize_node(pb_n, deserialization_info, graph=node,
                                                  verbose=verbose)

    # Shape entries are either constants or references to nodes by id.
    shape_list = []
    for shape in pb_node.shape:
        val_type = shape.WhichOneof("value")
        if val_type == "shape_const":
            shape_list.append(shape.shape_const)
        elif shape.shape_id not in graph.nodes:
            shape_list.append(node.nodes[shape.shape_id])
        else:
            shape_list.append(graph.nodes[shape.shape_id])
    node._shape = tuple(shape_list)
    return node
def _deserialize_domain(pb_dom, graph, node_name, info, write_graph=None):
    """Rebuild a ``Domain`` from a serialized DOM attribute.

    Args:
        pb_dom: Attribute whose ``dom.domains`` holds the serialized entries.
        graph: Graph whose ``nodes`` mapping is searched first for node references.
        node_name: Name of the node owning this domain; a NODE entry naming
            this node itself is skipped.
        info: Deserialization state; ``info['uuid_map']`` values are searched
            by ``name`` when a node is in neither ``graph`` nor ``write_graph``.
        write_graph: Optional secondary node container searched after ``graph``.

    Returns:
        ``Domain`` built from the tuple of decoded entries.

    Raises:
        KeyError: If a referenced node cannot be found.
        TypeError: If an entry has an unsupported attribute type.
    """
    doms = []
    for d in pb_dom.dom.domains:
        if d.type == pb.Attribute.Type.NODE:
            d_name = d.s.decode("utf-8")
            if d_name != node_name:
                if d_name in graph.nodes:
                    arg_node = graph.nodes[d_name]
                elif write_graph and d_name in write_graph.nodes:
                    arg_node = write_graph.nodes[d_name]
                else:
                    # Fall back to any node deserialized so far with that name.
                    arg_node = next(
                        (v for v in info['uuid_map'].values() if d_name == v.name), None)
                    if arg_node is None:
                        # Only collect the enclosing graph names when reporting an error.
                        all_graphs = []
                        g = graph
                        while g is not None:
                            all_graphs.append(g.name)
                            g = g.graph
                        raise KeyError(
                            f"Unable to find node in graph {d_name} for {node_name}: "
                            f"{graph.name}. All Graphs: {all_graphs}\n"
                            f"Write graph: {write_graph}")
                doms.append(arg_node)
        elif d.type == pb.Attribute.Type.NDARRAY:
            doms.append(proto_to_ndarray(d.nda))
        elif d.type == pb.Attribute.Type.INT32:
            doms.append(d.i32)
        elif d.type == pb.Attribute.Type.DOUBLE:
            doms.append(d.d)
        elif d.type == pb.Attribute.Type.STRING:
            doms.append(d.s.decode("utf-8"))
        elif d.type == pb.Attribute.Type.BOOL:
            doms.append(d.b)
        elif d.type == pb.Attribute.Type.NODES:
            arg_node = []
            for a in d.ss:
                a_name = a.decode("utf-8")
                if a_name in graph.nodes:
                    anode = graph.nodes[a_name]
                elif write_graph and a_name in write_graph.nodes:
                    anode = write_graph.nodes[a_name]
                else:
                    raise KeyError(f"Unable to find node in graph {a_name}")
                arg_node.append(anode)
            doms.append(arg_node)
        elif d.type == pb.Attribute.Type.NDARRAYS:
            doms.append([proto_to_ndarray(a) for a in d.ndas])
        elif d.type == pb.Attribute.Type.INT32S:
            doms.append(list(d.i32s))
        elif d.type == pb.Attribute.Type.DOUBLES:
            doms.append(list(d.ds))
        elif d.type == pb.Attribute.Type.STRINGS:
            doms.append([a.decode("utf-8") for a in d.ss])
        elif d.type == pb.Attribute.Type.BOOLS:
            # NOTE(review): ``b`` is also the scalar BOOL field; confirm the
            # repeated-bool field name in the schema.
            doms.append(list(d.b))
        else:
            raise TypeError(
                f"Cannot find deserializeable method for argument {d} with type {d.type}"
            )
    return Domain(tuple(doms))