Exemplo n.º 1
0
 def InitializeData(self, request, context):
     """Receive and store this client's local data set.

     ``request.x`` is deserialized into ``self.X`` and ``request.y`` into
     ``self.y``.

     Returns:
         functions_pb2.Reply carrying this client's identifier in
         ``str_reply`` and ``self.X.shape[0]`` (the number of samples, i.e.
         rows of X) in ``numeric_reply``.
     """
     self.X = deserialize(request.x)
     self.y = deserialize(request.y)
     logging.info(f"Client {self.device_index} initialized X {self.X.shape} and Y {self.y.shape}")
     client_id = self.identifier
     n_samples = self.X.shape[0]  # number of rows of X
     return functions_pb2.Reply(str_reply=client_id, numeric_reply=n_samples)
Exemplo n.º 2
0
    def Train(self, request, context):
        """Run ``request.q`` local gradient-descent rounds and return the new model.

        ``request.model.xtheta`` is the aggregated prediction ``Xtheta`` from the
        previous iteration. This client's own contribution ``self.X @ self.theta``
        is subtracted from it once, before the loop. The remainder (the part
        attributed to the other datacenters) is held fixed while ``self.theta``
        is updated with L2-regularized batch gradient steps.

        Args:
            request: message with ``model.xtheta`` (serialized array), ``q``
                (number of local rounds) and ``lambduh`` (regularization weight).
            context: RPC context; not used.

        Returns:
            functions_pb2.Model whose ``model`` field holds the serialized,
            updated ``self.theta``.
        """
        self._increase_global_counts()
        self.Xtheta = deserialize(request.model.xtheta)
        num_rounds = int(request.q)
        lambduh = float(request.lambduh)
        logging.info(
            msg=
            f"Client {self.device_index} the dimension of Xtheta {self.Xtheta.shape}, Theta {self.theta.shape}"
        )
        # Part of Xtheta contributed by the other datacenters for the same
        # label space in the last iteration (assumes the label space is shared).
        # It stays fixed during the local rounds below.
        Xtheta_from_other_DC = self.Xtheta - self.X @ self.theta
        n_samples = len(self.X)
        for rounds in range(num_rounds):
            # `rounds` needs a %-placeholder; passing it as a bare extra
            # argument makes logging fail to format the record.
            logging.info("Starting local round %s", rounds)
            # Batch gradient of the (1/n-scaled) least-squares loss on the
            # combined prediction (other DCs + local), plus the L2 term.
            residual = (Xtheta_from_other_DC + self.X @ self.theta) - self.y
            grad = 1 / n_samples * self.X.T @ residual + lambduh * self.theta
            if self.decreasing_step:
                # Step size decays as 1/sqrt(global round count + 1).
                step = self.alpha / np.sqrt(self.global_rounds_counter + 1)
            else:
                step = self.alpha
            self.theta = self.theta - step * grad
        # Refresh the cached local prediction with the updated theta
        # (not read by this method itself).
        self.Xtheta = self.X @ self.theta
        return functions_pb2.Model(model=serialize(self.theta))
Exemplo n.º 3
0
def sendModel(client_id, stub, firstInitFlag, global_model):
    """Push the global model to one client and collect its local prediction.

    Args:
        client_id: identifier the responding client is expected to report.
        stub: client stub exposing ``UpdateLocalModels`` (presumably gRPC).
        firstInitFlag: not used; kept so existing call sites keep working.
        global_model: model array, serialized into ``functions_pb2.Model.model``.

    Returns:
        ``(client_id, xtheta)`` where ``xtheta`` is the deserialized
        ``xtheta`` field of the client's reply.

    Raises:
        AssertionError: if the reply's ``id`` differs from ``client_id``.
    """
    request = functions_pb2.Model(model=serialize(global_model))
    res = stub.UpdateLocalModels(request)
    xtheta = deserialize(res.xtheta)
    logging.info(msg=f"Client {res.id} sends back Xtheta {xtheta[0:10]} ")
    # Explicit check instead of `assert` so the validation survives `python -O`.
    # AssertionError is kept so any existing handlers still match.
    if res.id != client_id:
        raise AssertionError(
            f"Expected a reply from client {client_id}, got one from {res.id}")
    return client_id, xtheta
Exemplo n.º 4
0
def getClientModels(client_id, stub):
    """Query the current model from a client.

    Args:
        client_id: identifier the responding client is expected to report.
        stub: client stub exposing ``SendModel`` (presumably gRPC).

    Returns:
        ``(client_id, model)`` where ``model`` is the deserialized ``model``
        field of the client's reply.

    Raises:
        AssertionError: if the reply's ``id`` differs from ``client_id``.
    """
    empty = functions_pb2.Empty(value=1)
    logging.info(f"Getting model from client {client_id}")
    res = stub.SendModel(empty)
    model = deserialize(res.model)
    # Explicit check instead of `assert` so the validation survives `python -O`.
    # AssertionError is kept so any existing handlers still match.
    if res.id != client_id:
        raise AssertionError(
            f"Expected a model from client {client_id}, got one from {res.id}")
    # Report through logging, like the rest of this module, and only once the
    # reply has been validated.
    logging.info(f"Node {client_id}: Obtained model from client {client_id}")
    return client_id, model
Exemplo n.º 5
0
 def InitializeParams(self, request, context):
     """Store the hyper-parameters and initial model sent by the server.

     Copies ``alpha``, ``lambduh``, ``device_index``, ``dc_index`` and
     ``decreasing_step`` from the request onto the instance, and deserializes
     ``request.model`` into ``self.theta``.

     Returns:
         functions_pb2.Reply whose ``str_reply`` is this client's identifier.
     """
     reply = functions_pb2.Reply(str_reply=self.identifier)
     self.alpha, self.lambduh = request.alpha, request.lambduh
     self.device_index, self.dc_index = request.device_index, request.dc_index
     self.theta = deserialize(request.model)
     self.decreasing_step = request.decreasing_step
     logging.info(f"Client {self.device_index} Shape of the received array is the following: {self.theta.shape}")
     logging.info(f"Client {self.device_index} is initialized with identifier {self.identifier}")
     return reply
Exemplo n.º 6
0
def trainFunc(client_id, stub, q, lambduh, xtheta, model=None):
    """Ask one client to run local training and return its updated model.

    Builds a ``functions_pb2.TrainConfig`` carrying the number of local rounds
    ``q``, the regularization weight ``lambduh`` and the serialized ``xtheta``,
    sends it via ``stub.Train`` and deserializes the reply.

    Args:
        client_id: identifier of the client, echoed back in the result.
        stub: client stub exposing ``Train`` (presumably gRPC).
        q: number of local rounds.
        lambduh: regularization weight.
        xtheta: aggregated prediction array sent to the client.
        model: currently ignored; only ``xtheta`` is sent.

    Returns:
        ``(client_id, updated_model)`` with the deserialized ``res.model``.
    """
    config = functions_pb2.TrainConfig(q=q, lambduh=lambduh)
    config.model.xtheta = serialize(xtheta)
    reply = stub.Train(config)
    updated_model = deserialize(reply.model)
    logging.info(f"Server received model from Client {client_id} ==> {updated_model.shape}")
    return client_id, updated_model
Exemplo n.º 7
0
 def UpdateLocalModels(self, request, context):
     """Adopt the model pushed by the server and reply with the local prediction.

     Deserializes ``request.model`` into ``self.theta``, recomputes
     ``self.Xtheta = self.X @ self.theta``, and returns that prediction
     (serialized) together with this client's identifier.
     """
     self.theta = deserialize(request.model)
     prediction = self.X @ self.theta
     self.Xtheta = prediction
     payload = serialize(prediction)
     return functions_pb2.Model(xtheta=payload, id=self.identifier)