Esempio n. 1
0
    def __init__(self, manager, port, parameter_total, key_range):
        """Create a worker-side client and register it with the manager.

        The constructor validates ``key_range``, initialises the local node
        description and transport, sends an ``ADD_WORKER`` message to
        ``manager``, starts the background receive thread and then blocks
        until the registration is acknowledged and at least one server is
        known. Any failure logs an error and terminates the process with
        exit status 1.

        Args:
            manager: Host of the server manager; used as the receiver of the
                ``ADD_WORKER`` message.
            port: Port handed to the transport's ``Init``.
            parameter_total: Total number of parameter keys; one local
                ``Value`` slot is allocated per key.
            key_range: Object exposing ``Begin`` and ``End``. It must satisfy
                ``0 <= Begin < End <= parameter_total`` (logged as the
                half-open interval ``[Begin, End)``).
        """
        range_is_invalid = (key_range.Begin < 0
                            or key_range.End <= key_range.Begin
                            or key_range.End > parameter_total)
        if range_is_invalid:
            logger.error(
                "invalid key range [{}, {}), total parameter {}".format(
                    key_range.Begin, key_range.End, parameter_total))
            sys.exit(1)

        self.node = Node()
        self.node.Index = -1
        self.node.Host = get_host()
        self.node.Ready = True
        self.node.Role = WORKER
        self.van = ZMQ()
        self.van.Init(port)
        self.parameterTotal = parameter_total
        self.servers = []
        self.keyRange = key_range
        # Build one distinct Value per key. ``[Value()] * n`` would put the
        # *same* object in every slot, so mutating one slot in place would
        # silently change all of them.
        self.parameters = [Value() for _ in xrange(parameter_total)]
        # Tracks whether a pull request has been answered.
        self.pullResponded = ExpiringDict(max_len=1000,
                                          max_age_seconds=600)

        # Register this worker with the server manager to obtain the servers.
        msg = Message()
        msg.Sender.Host = self.node.Host
        msg.Receiver.Host = manager
        msg.Command = ADD_WORKER
        msgId = self.van.Send(msg)

        # Start receiving server-cluster change notifications.
        # NOTE(review): presumably this thread is what fills self.servers --
        # confirm against ReceiveThread.
        th = self.ReceiveThread(self)
        th.setDaemon(True)
        th.start()

        # Wait for the registration to be acknowledged. The result used to be
        # compared against the ``Node`` class with ``is``, which cannot match
        # a normal return value, so a failed registration went unnoticed.
        # Any falsy result (None or False) is now treated as a failure, in
        # line with how the other Wait* results are checked.
        if not self.van.WaitAck(msgId, 5):
            logger.error("regist worker failed")
            sys.exit(1)

        # The server list is mandatory: poll once per second, up to ten
        # times, and give up if it is still empty afterwards.
        for _ in xrange(10):
            if len(self.servers) != 0:
                break
            time.sleep(1)
        if len(self.servers) == 0:
            logger.error("could not get servers")
            sys.exit(1)

        logger.info("init ps client ok")
Esempio n. 2
0
 def __init__(self, manager, port, parameter_total, key_range,
              init_param_func, valuesEachKey):
     """Connect to the parameter server and obtain initial parameters.

     A ``PsClient`` is created and a pull is issued. When the pull
     succeeds and the first key of ``key_range`` already carries values,
     the parameters held by the server are used as they are. Otherwise
     every one of the ``parameter_total`` keys is initialised locally with
     ``valuesEachKey`` copies of a single ``init_param_func()`` result, and
     the freshly initialised parameters are pushed to the server.
     """
     client = PsClient(manager, port, parameter_total, key_range)
     self.psClient = client
     self.KeyRange = key_range

     # Try to fetch existing parameters from the parameter server first.
     pull_id = client.Pull()
     pulled = client.WaitPull(pull_id, 0)
     snapshot = client.GetAllParameter()
     if (pulled and snapshot[key_range.Begin] is not None
             and len(snapshot[key_range.Begin].Values) != 0):
         logger.info("init parameter from ps")
         return

     # Nothing usable on the server: build the initial values locally.
     fresh = []
     for _ in xrange(parameter_total):
         entry = Value()
         entry.Values.extend([init_param_func()] * valuesEachKey)
         fresh.append(entry)
     client.UpdateLocalParameter(fresh)
     # Publish the initialised parameters to the server cluster.
     push_id = client.Push()
     client.WaitPush(push_id, 5)
     logger.info("init parameter by init function")
Esempio n. 3
0
def TrainLrWithGD(corpusFile, splitNum, splitIndex, epoch, batch, eta, manager,
                  port, ParameterTotal, KeyRange, synRound):
    """Train a logistic-regression model with mini-batch gradient descent.

    For every epoch the corpus is streamed through ``CorpusGenerator``;
    each time ``batch`` samples have been collected one update step is run,
    and any samples left over at the end of the epoch form a final, smaller
    step. Only the weights in ``[KeyRange.Begin, KeyRange.End)`` are updated
    and pushed by this worker.

    Args:
        corpusFile, splitNum, splitIndex: Forwarded to ``CorpusGenerator``.
        epoch: Number of passes over the corpus.
        batch: Number of samples per update step.
        eta: Learning rate applied to the gradient.
        manager, port: Forwarded to ``LR`` for the parameter-server client.
        ParameterTotal: Total number of parameter keys.
        KeyRange: Object with ``Begin``/``End`` delimiting the keys this
            worker updates.
        synRound: On every step whose 1-based index is a multiple of this
            value, the Pull issued in that step becomes the one waited on
            before subsequent steps. Must be non-zero.

    Returns:
        The ``LR`` instance used for training.
    """
    begin = time.time()
    lr = LR(manager, port, ParameterTotal, KeyRange, random.random, 1)
    step = 1  # 1-based update counter (formerly ``iter``, a builtin name)
    waitedPullMsgId = 0
    use_time = []
    for ep in xrange(epoch):
        logger.info("epoch={}".format(ep))
        corpusGenerator = CorpusGenerator(corpusFile, splitNum, splitIndex,
                                          ParameterTotal)
        xBatch = []
        yBatch = []
        for X, Y in corpusGenerator:
            xBatch.append(X)
            yBatch.append(Y)
            if len(xBatch) >= batch:
                waitedPullMsgId, elapsed = _gd_train_batch(
                    lr, xBatch, yBatch, eta, KeyRange, step, synRound,
                    waitedPullMsgId)
                step += 1
                use_time.append(elapsed)
                xBatch = []
                yBatch = []
        # np.mean of an empty sequence yields nan plus a RuntimeWarning, so
        # report 0.0 when no full batch has been processed yet.
        mean_time = np.mean(np.array(use_time)) if use_time else 0.0
        logger.debug("update paramter {} times, mean use time {}".format(
            len(use_time), mean_time))
        if len(xBatch) > 0:
            # Trailing partial batch of this epoch; its duration is not
            # recorded in use_time (same as before the refactoring).
            waitedPullMsgId, _ = _gd_train_batch(
                lr, xBatch, yBatch, eta, KeyRange, step, synRound,
                waitedPullMsgId)
            step += 1

    logger.info(
        "train lr with gd finished, use {} seconds".format(time.time() -
                                                           begin))
    return lr


def _gd_train_batch(lr, xBatch, yBatch, eta, KeyRange, step, synRound,
                    waitedPullMsgId):
    """Run one gradient-descent step on a batch and push the new weights.

    Returns a ``(waitedPullMsgId, elapsed)`` tuple: the id of the Pull to
    wait on before the next step, and the seconds spent computing the
    update (excluding the pull wait and the push).
    """
    psClient = lr.psClient
    if waitedPullMsgId > 0:
        # Wait for the previously tracked Pull to finish.
        if not psClient.WaitPull(waitedPullMsgId, 1):
            logger.error("wait pull timeout")
    msgId = psClient.Pull()
    if step % synRound == 0:
        waitedPullMsgId = msgId

    started = time.time()
    x = np.array(xBatch)
    w = []
    for value in psClient.GetAllParameter():
        if len(value.Values) >= 1:
            w.append(value.Values[0])
        else:
            # NOTE(review): a skipped key shortens ``w`` and shifts every
            # later index; this is only logged, as before.
            logger.error(
                "parameters of one key less than 1: {}".format(
                    len(value.Values)))
    w = np.array(w)
    y_hat = lr.fn(w, x)
    y = np.array(yBatch).reshape(len(yBatch), 1)
    # Only the gradient for this worker's key range is needed.
    g = lr.grad(y, y_hat, x[:, KeyRange.Begin:KeyRange.End])
    # Core gradient-descent rule, applied to the owned range only.
    w[KeyRange.Begin:KeyRange.End] -= eta * g
    values = []
    for i in xrange(KeyRange.Begin, KeyRange.End):
        value = Value()
        value.Values.append(w[i])
        values.append(value)
    elapsed = time.time() - started

    psClient.UpdateLocalRangedParameter(values)
    psClient.Push()
    return waitedPullMsgId, elapsed
Esempio n. 4
0
def TrainLrWithFTRL(corpusFile, splitNum, splitIndex, epoch, batch, manager,
                    port, ParameterTotal, KeyRange, alpha, beta, l1, l2,
                    synRound):
    """Train a logistic-regression model with the FTRL update rule.

    Each parameter key stores two values, ``z`` and ``n``. For every epoch
    the corpus is streamed through ``CorpusGenerator``; each time ``batch``
    samples have been collected one update step is run, and any samples
    left over at the end of the epoch form a final, smaller step. Only the
    entries in ``[KeyRange.Begin, KeyRange.End)`` are updated and pushed by
    this worker.

    Args:
        corpusFile, splitNum, splitIndex: Forwarded to ``CorpusGenerator``.
        epoch: Number of passes over the corpus.
        batch: Number of samples per update step.
        manager, port: Forwarded to ``LR`` for the parameter-server client.
        ParameterTotal: Total number of parameter keys.
        KeyRange: Object with ``Begin``/``End`` delimiting the keys this
            worker updates.
        alpha, beta, l1, l2: FTRL hyper-parameters used in the weight and
            ``sigma`` formulas.
        synRound: On every step whose 1-based index is a multiple of this
            value, the Pull issued in that step becomes the one waited on
            before subsequent steps. Must be non-zero.

    Returns:
        The ``LR`` instance used for training.
    """
    begin = time.time()
    # z and n start at 0; note that n must never be negative.
    lr = LR(manager, port, ParameterTotal, KeyRange, init_zero, 2)
    step = 1  # 1-based update counter (formerly ``iter``, a builtin name)
    waitedPullMsgId = 0
    use_time = []
    for ep in xrange(epoch):
        logger.info("epoch={}".format(ep))
        corpusGenerator = CorpusGenerator(corpusFile, splitNum, splitIndex,
                                          ParameterTotal)
        xBatch = []
        yBatch = []
        for X, Y in corpusGenerator:
            xBatch.append(X)
            yBatch.append(Y)
            if len(xBatch) >= batch:
                waitedPullMsgId, elapsed = _ftrl_train_batch(
                    lr, xBatch, yBatch, KeyRange, alpha, beta, l1, l2, step,
                    synRound, waitedPullMsgId)
                step += 1
                use_time.append(elapsed)
                xBatch = []
                yBatch = []
        # np.mean of an empty sequence yields nan plus a RuntimeWarning, so
        # report 0.0 when no full batch has been processed yet.
        mean_time = np.mean(np.array(use_time)) if use_time else 0.0
        logger.debug("update paramter {} times, mean use time {}".format(
            len(use_time), mean_time))
        if len(xBatch) > 0:
            # Trailing partial batch of this epoch; its duration is not
            # recorded in use_time (same as before the refactoring).
            waitedPullMsgId, _ = _ftrl_train_batch(
                lr, xBatch, yBatch, KeyRange, alpha, beta, l1, l2, step,
                synRound, waitedPullMsgId)
            step += 1

    logger.info(
        "train lr with ftrl finished, use {} seconds".format(time.time() -
                                                             begin))
    return lr


def _ftrl_train_batch(lr, xBatch, yBatch, KeyRange, alpha, beta, l1, l2, step,
                      synRound, waitedPullMsgId):
    """Run one FTRL step on a batch and push the updated ``z``/``n`` pairs.

    Returns a ``(waitedPullMsgId, elapsed)`` tuple: the id of the Pull to
    wait on before the next step, and the seconds spent computing the
    update (excluding the pull wait and the push).
    """
    psClient = lr.psClient
    if waitedPullMsgId > 0:
        # Wait for the previously tracked Pull to finish.
        if not psClient.WaitPull(waitedPullMsgId, 1):
            logger.error("wait pull timeout")
    msgId = psClient.Pull()
    if step % synRound == 0:
        waitedPullMsgId = msgId

    started = time.time()
    x = np.array(xBatch)
    z = []
    n = []
    for v in psClient.GetAllParameter():
        if len(v.Values) >= 2:
            z.append(v.Values[0])
            n.append(v.Values[1])
        else:
            # NOTE(review): a skipped key shortens ``z``/``n`` and shifts
            # every later index; this is only logged, as before.
            logger.error(
                "parameters of one key less than 2: {}".format(
                    len(v.Values)))
    z = np.array(z)
    n = np.array(n)
    # Core FTRL formula: the weight is 0 where |z| <= l1, otherwise the
    # closed-form solution. Computed element-wise in a single vectorised
    # expression instead of a Python-level loop.
    w = np.where(
        np.abs(z) <= l1, 0.0,
        (np.sign(z) * l1 - z) / (l2 + (beta + np.sqrt(n)) / alpha))
    y_hat = lr.fn(w, x)
    y = np.array(yBatch).reshape(len(yBatch), 1)
    # Only the gradient for this worker's key range is needed.
    g = lr.grad(y, y_hat, x[:, KeyRange.Begin:KeyRange.End])
    owned = slice(KeyRange.Begin, KeyRange.End)
    sigma = (np.sqrt(n[owned] + g * g) - np.sqrt(n[owned])) / alpha
    # Update only the range this worker is responsible for.
    z[owned] += g - sigma * w[owned]
    n[owned] += g * g
    values = []
    for i in xrange(KeyRange.Begin, KeyRange.End):
        value = Value()
        value.Values.extend([z[i], n[i]])
        values.append(value)
    elapsed = time.time() - started

    psClient.UpdateLocalRangedParameter(values)
    psClient.Push()
    return waitedPullMsgId, elapsed
Esempio n. 5
0
File: test.py Progetto: frankiegu/ps
        if WaitedPushMsgId > 0:
            client.WaitPush(WaitedPushMsgId, 0)

        # 等待上上次的Pull完成
        if WaitedPullMsgId > 0:
            client.WaitPull(WaitedPullMsgId, 0)

        # Pull
        msgId = client.Pull()
        if i % 3 == 0:
            WaitedPullMsgId = msgId

        # Push
        values = []
        for key in xrange(SelfRange.Begin, SelfRange.End):
            ele = Value()
            for j in xrange(ParameterCountOfKey):
                ele.Values.append(random.random())
            values.append(ele)
        client.UpdateLocalRangedParameter(values)
        msgId = client.Push()
        if i % 3 == 0:
            WaitedPushMsgId = msgId

        time.sleep(0.1)

    params = client.GetAllParameter()
    for i in xrange(SelfRange.Begin, SelfRange.End):
        l = params[i]
        print  i, "param", l.Values
        if i >= SelfRange.Begin + 10: