示例#1
0
class DC_QSGDServer(Codec):
    """Parameter-server codec for delay-compensated quantized SGD (DC-QSGD).

    Keeps the global weights and applies the quantized updates pushed by
    clients. For each client it keeps a snapshot of the global weights taken
    right after that client's last push. When a client sends a SignalPack,
    the reply is the quantized difference between the current global weights
    and that client's snapshot, i.e. Q(w_{t+tau} - w_t).
    """

    # Codec used to quantize replies; its resolution is read from the
    # Quantization_Resolution_Server setting.
    codex = quant(
        build_quantization_space(
            int(GlobalSettings.get_params(Quantization_Resolution_Server))))

    def __init__(self, node_id):
        super().__init__(node_id)
        # Starts as the scalar 0. The first `-=` (0 - array) rebinds it to a
        # new array; later `-=` calls update that array in place.
        self.__global_weights: np.ndarray = 0
        # node_id -> snapshot of the global weights after that node's last push.
        self.__weights_states: Dict[int, np.ndarray] = {}

    def dispose(self):
        pass

    def update_blocks(self, block_weight: BlockWeight):
        pass

    def receive_blocks(self, package: IPack):
        """Handle a package coming from a client.

        QuantizedPack: decode it with the client codec, subtract it from the
        global weights and record a snapshot for the sending node. Returns
        None (no reply).
        SignalPack: reply to the sender with a bare SignalPack if no update has
        been received yet; otherwise reply with the quantized difference
        between the current global weights and the sender's snapshot (the
        whole global weights if the sender has no snapshot yet).
        Any other package type is ignored and None is returned.
        """
        if isinstance(package, QuantizedPack):
            self.__global_weights -= package.decode(DC_QSGDClient.codex)
            # Store a copy, not a reference. `-=` above mutates the array in
            # place on later calls, so a bare reference would track the live
            # weights and the reply delta below would always be zero.
            self.__weights_states[package.node_id] = np.copy(
                self.__global_weights)
        elif isinstance(package, SignalPack):
            # returns Q(w_t+\tau - w_t)
            if not isinstance(self.__global_weights, np.ndarray):
                # No update received yet: nothing to diff against.
                reply = SignalPack(Parameter_Server)
            else:
                reply = QuantizedPack(
                    Parameter_Server, self.__global_weights -
                    self.__weights_states.get(package.node_id, 0),
                    DC_QSGDServer.codex)
            return netEncapsulation(package.node_id, reply)
示例#2
0
class QuantizedParaServer(Codec):
    """Parameter-server codec that replies with quantized global weights.

    Every received package's content is subtracted from the global weights.
    An exponential moving average of its element-wise square (decay 0.9) is
    also maintained. The reply is built from Q_w(global weights, q_space,
    sqrt(average)); how Q_w uses the third argument is defined elsewhere
    (presumably as a per-element scale -- TODO confirm).
    """

    q_space = build_quantization_space(
        int(GlobalSettings.get_params(Quantization_Resolution_Server)))

    def __init__(self, node_id):
        super().__init__(node_id)
        # Both start as the scalar 0 and become arrays after the first update.
        self.__global_weights: np.ndarray = 0
        self.__vt: np.ndarray = 0
        # Decay factor of the moving average of squared updates.
        self.__beta: float = 0.9

    def dispose(self):
        """Nothing to release."""

    def update_blocks(self, block_weight: BlockWeight):
        """The server produces no blocks of its own."""

    def receive_blocks(self, package: IPack):
        """Apply a client update and answer the sender with quantized weights."""
        update = package.content
        self.__global_weights -= update
        decay = self.__beta
        self.__vt = decay * self.__vt + np.square(update) * (1 - decay)
        quantized_parts = Q_w(self.__global_weights,
                              QuantizedParaServer.q_space,
                              np.sqrt(self.__vt))
        return netEncapsulation(package.node_id,
                                QuantizedPack(Parameter_Server,
                                              *quantized_parts))
示例#3
0
class DC_QSGDClient(Codec):
    """Worker-side codec for delay-compensated quantized SGD (DC-QSGD).

    update_blocks stores the block content and sends a SignalPack to the
    parameter server. receive_blocks decodes the server's weight drift and
    compensates the stored content with lambda * g * g * drift. It records
    `compensated - drift` as the local result and pushes the compensated
    value back to the server, quantized. The stored content is presumably
    the local gradient g(w_t) -- TODO confirm against BlockWeight producers.
    """

    codex = quant(
        build_quantization_space(
            int(GlobalSettings.get_params(Quantization_Resolution_Client))))

    def __init__(self, node_id):
        super().__init__(node_id)
        # Latest block content; the scalar 0 until update_blocks is called.
        self.__gw_t: np.ndarray = 0
        # Delay-compensation strength.
        self.__lambda = 0.3

    def dispose(self):
        """Nothing to release."""

    def update_blocks(self,
                      block_weight: BlockWeight) -> netEncapsulation[IPack]:
        """Remember the block content and signal the parameter server."""
        self.__gw_t = block_weight.content
        signal = SignalPack(self.node_id)
        return netEncapsulation(Parameter_Server, signal)

    def receive_blocks(self, package: IPack) -> netEncapsulation[IPack]:
        """Apply delay compensation using the drift sent by the server."""
        drift = package.decode(DC_QSGDServer.codex)
        grad = self.__gw_t
        # np.multiply is kept on purpose: it matches the original's dtype
        # promotion when grad is still the scalar 0.
        correction = np.multiply(np.multiply(grad, grad), drift)
        compensated = grad + self.__lambda * correction

        def accumulate(x, y):
            # Sum with the pending result, or start from y when there is none.
            if x is None:
                return y
            return x + y

        self.set_result(compensated - drift, accumulate)
        outgoing = QuantizedPack(self.node_id, compensated, DC_QSGDClient.codex)
        return netEncapsulation(Parameter_Server, outgoing)
示例#4
0
 def receive_blocks(self, package: IPack):
     """Subtract the package content from the global weights and reply.

     The reply to the sender is a NormalPack holding the global weights
     converted with `.astype(...)` to the type given by the
     Low_Precision_Server setting.
     """
     self.__global_weights -= package.content
     target_type = GlobalSettings.get_params(Low_Precision_Server)
     low_precision = self.__global_weights.astype(target_type)
     return netEncapsulation(package.node_id,
                             NormalPack(Parameter_Server, low_precision))
示例#5
0
 def update_blocks(
         self, block_weight: BlockWeight) -> netEncapsulation[NormalPack]:
     """Send the block content to the parameter server.

     The content is converted with `.astype(...)` to the type given by the
     Full_Precision_Client setting and wrapped in a NormalPack tagged with
     this node's id.
     """
     target_type = GlobalSettings.get_params(Full_Precision_Client)
     converted = block_weight.content.astype(target_type)
     return netEncapsulation(Parameter_Server,
                             NormalPack(self.node_id, converted))
示例#6
0
    def build():
        """Initialise SGQPackage's shared quantization codec tables.

        Builds a quantization space of 2**resolution - 1 levels, with the
        resolution taken from Quantization_Resolution_Client. It then collects,
        in ascending order starting at 0, as many non-negative integers as the
        space has levels, skipping any integer within 1e-3 of a level. The
        resulting code list is installed on the codec via set_codec.
        """
        SGQPackage.__quant_codec = codec()
        level_count = 2**int(
            GlobalSettings.get_params(Quantization_Resolution_Client)) - 1
        SGQPackage.__quant_space = build_quantization_space(level_count)
        SGQPackage.__quant_code = codes = []

        candidate = 0
        while len(codes) != len(SGQPackage.__quant_space):
            collides = any(abs(level - candidate) < 1e-3
                           for level in SGQPackage.__quant_space)
            if not collides:
                codes.append(candidate)
            candidate += 1

        SGQPackage.__quant_codec.set_codec(codes)
示例#7
0
class QuantizedClient(Codec):
    """Worker-side codec that pushes quantized blocks to the parameter server.

    Outgoing block content is passed through Q_g with this class's
    quantization space. Each package received back simply replaces the
    pending result.
    """

    q_space = build_quantization_space(
        int(GlobalSettings.get_params(Quantization_Resolution_Client)))

    def __init__(self, node_id):
        super().__init__(node_id)

    def dispose(self):
        """Nothing to release."""

    def update_blocks(
            self,
            block_weight: BlockWeight) -> netEncapsulation[QuantizedPack]:
        """Quantize the block content with Q_g and address it to the server."""
        sender = self.node_id
        quantized_parts = Q_g(block_weight.content, QuantizedClient.q_space)
        return netEncapsulation(Parameter_Server,
                                QuantizedPack(sender, *quantized_parts))

    def receive_blocks(self, package: IPack):
        """Replace the pending result with the received content."""

        def keep_latest(x, y):
            # Ignore whatever is pending; the newest value wins.
            return y

        self.set_result(package.content, keep_latest)
示例#8
0
class QuantizedParaServer(Codec):
    """Parameter-server codec exchanging quantized packs with clients.

    Each received pack is decoded with the client codec and subtracted from
    the global weights. The sender then gets the global weights back,
    quantized with this server's codec.
    """

    codex = quant(
        build_quantization_space(
            int(GlobalSettings.get_params(Quantization_Resolution_Server))))

    def __init__(self, node_id):
        super().__init__(node_id)
        # The scalar 0 until the first update turns it into an array.
        self.__global_weights: np.ndarray = 0

    def dispose(self):
        """Nothing to release."""

    def update_blocks(self, block_weight: BlockWeight):
        """The server produces no blocks of its own."""

    def receive_blocks(self, package: IPack):
        """Apply a client's quantized update and reply with quantized weights."""
        # NOTE(review): reads QuantizedClient.codex; a QuantizedClient that only
        # defines q_space (no codex) would raise AttributeError here -- confirm
        # which client version pairs with this server.
        update = package.decode(QuantizedClient.codex)
        self.__global_weights -= update
        answer = QuantizedPack(Parameter_Server, self.__global_weights,
                               QuantizedParaServer.codex)
        return netEncapsulation(package.node_id, answer)