def __init__(self, manager, port, parameter_total, key_range):
    """Set up a worker-side client and register it with the server manager.

    Validates ``key_range`` against ``parameter_total``, initializes the
    local node and transport, sends an ``ADD_WORKER`` message to
    ``manager``, starts the background receive thread and then waits until
    the server list has been filled in.

    Args:
        manager: Host of the server manager; used as the receiver of the
            registration message.
        port: Port handed to the transport's ``Init``.
        parameter_total: Total number of parameter keys.
        key_range: Object with ``Begin``/``End`` attributes describing the
            half-open key interval ``[Begin, End)``; it must lie inside
            ``[0, parameter_total]`` and be non-empty.

    Exits the process with status 1 (``sys.exit(1)``) when the key range
    is invalid, when registration is not acknowledged, or when no servers
    are known after roughly ten seconds.
    """
    if (key_range.Begin < 0 or key_range.End <= key_range.Begin
            or key_range.End > parameter_total):
        logger.error(
            "invalid key range [{}, {}), total parameter {}".format(
                key_range.Begin, key_range.End, parameter_total))
        sys.exit(1)

    self.node = Node()
    self.node.Index = -1
    self.node.Host = get_host()
    self.node.Ready = True
    self.node.Role = WORKER

    self.van = ZMQ()
    self.van.Init(port)

    self.parameterTotal = parameter_total
    self.servers = []
    self.keyRange = key_range
    # Build one independent Value per key. ``[Value()] * n`` would put the
    # very same object into every slot, so an in-place change to one key
    # would show up under all keys.
    self.parameters = [Value() for _ in xrange(parameter_total)]
    # Records whether a pull request has been answered.
    self.pullResponded = ExpiringDict(max_len=1000, max_age_seconds=600)

    # Register this worker with the server manager to obtain the servers.
    msg = Message()
    msg.Sender.Host = self.node.Host
    msg.Receiver.Host = manager
    msg.Command = ADD_WORKER
    msgId = self.van.Send(msg)

    # Start receiving change notifications about the server cluster.
    th = self.ReceiveThread(self)
    th.setDaemon(True)
    th.start()

    # Wait for the registration to be acknowledged. The previous check
    # compared the result with the ``Node`` class by identity, which is
    # never true for an ordinary return value, so failures went unnoticed.
    # NOTE(review): assumes WaitAck returns a falsy value (None/False) on
    # timeout, like WaitPull does at its call sites - confirm.
    if not self.van.WaitAck(msgId, 5):
        logger.error("regist worker failed")
        sys.exit(1)

    # The server list is mandatory: poll for it, and give up after ten
    # one-second waits.
    for _ in xrange(10):
        if self.servers:
            break
        time.sleep(1)
    if not self.servers:
        logger.error("could not get servers")
        sys.exit(1)
    logger.info("init ps client ok")
def __init__(self, manager, port, parameter_total, key_range,
             init_param_func, valuesEachKey):
    """Create the parameter-server client and make sure parameters exist.

    The parameters are pulled from the parameter server first. If the pull
    does not succeed, or the entry at ``key_range.Begin`` is missing or
    holds no values, every key is filled locally and pushed to the servers.

    Args:
        manager, port, parameter_total, key_range: Forwarded unchanged to
            ``PsClient``.
        init_param_func: Zero-argument callable; it is called once per key
            and its result is repeated ``valuesEachKey`` times for that key.
        valuesEachKey: Number of values stored under each key.
    """
    client = PsClient(manager, port, parameter_total, key_range)
    self.psClient = client
    self.KeyRange = key_range

    # Try to fetch already-initialized parameters from the servers.
    pull_id = client.Pull()
    pulled = client.WaitPull(pull_id, 0)
    current = client.GetAllParameter()

    remote_ready = bool(pulled)
    if remote_ready:
        first = current[key_range.Begin]
        remote_ready = first is not None and len(first.Values) != 0

    if remote_ready:
        logger.info("init parameter from ps")
        return

    # Nothing usable on the servers: build the initial values locally.
    seeded = []
    for _ in xrange(parameter_total):
        entry = Value()
        fill = init_param_func()
        entry.Values.extend([fill] * valuesEachKey)
        seeded.append(entry)
    client.UpdateLocalParameter(seeded)

    # Publish the freshly initialized parameters to the server cluster.
    push_id = client.Push()
    client.WaitPush(push_id, 5)
    logger.info("init parameter by init function")
def _gd_minibatches(samples, batch):
    """Yield ``(xBatch, yBatch)`` lists of ``batch`` samples, then the rest.

    A batch is emitted as soon as it holds at least ``batch`` samples; a
    final, shorter batch is emitted when samples are left over.
    """
    xs, ys = [], []
    for X, Y in samples:
        xs.append(X)
        ys.append(Y)
        if len(xs) >= batch:
            yield xs, ys
            xs, ys = [], []
    if xs:
        yield xs, ys


def TrainLrWithGD(corpusFile, splitNum, splitIndex, epoch, batch, eta,
                  manager, port, ParameterTotal, KeyRange, synRound):
    """Train a logistic regression with mini-batch gradient descent.

    For every mini-batch the latest parameters are pulled, the gradient
    for this worker's key range is computed, the range is updated and the
    result is pushed back to the parameter server.

    Args:
        corpusFile, splitNum, splitIndex: Passed to ``CorpusGenerator``
            together with ``ParameterTotal`` to produce ``(X, Y)`` samples.
        epoch: Number of passes over the corpus.
        batch: Number of samples per mini-batch.
        eta: Step size multiplied with the gradient.
        manager, port: Passed to ``LR`` for the parameter-server client.
        ParameterTotal: Total number of parameter keys.
        KeyRange: Object with ``Begin``/``End``; only ``w[Begin:End]`` is
            updated and pushed by this worker.
        synRound: Every ``synRound``-th pull has its message id remembered;
            later updates wait for that pull before issuing a new one.

    Returns:
        The ``LR`` instance used for training.
    """
    begin = time.time()
    lr = LR(manager, port, ParameterTotal, KeyRange, random.random, 1)
    step = 1
    WaitedPullMsgId = 0
    use_time = []
    for ep in xrange(epoch):
        logger.info("epoch={}".format(ep))
        corpusGenerator = CorpusGenerator(corpusFile, splitNum, splitIndex,
                                          ParameterTotal)
        for xBatch, yBatch in _gd_minibatches(corpusGenerator, batch):
            if WaitedPullMsgId > 0:
                # Wait for the earlier pull to complete before continuing.
                if not lr.psClient.WaitPull(WaitedPullMsgId, 1):
                    logger.error("wait pull timeout")
            msgId = lr.psClient.Pull()
            if step % synRound == 0:
                WaitedPullMsgId = msgId
            step += 1

            t1 = time.time()
            x = np.array(xBatch)
            w = []
            for value in lr.psClient.GetAllParameter():
                if len(value.Values) >= 1:
                    w.append(value.Values[0])
                else:
                    # NOTE(review): a key without values is skipped, which
                    # shifts the indices of all later keys in ``w`` -
                    # confirm this cannot happen after initialization.
                    logger.error(
                        "parameters of one key less than 1: {}".format(
                            len(value.Values)))
            w = np.array(w)
            y_hat = lr.fn(w, x)
            y = np.array(yBatch).reshape(len(yBatch), 1)
            # Only the gradient for this worker's key range is needed.
            g = lr.grad(y, y_hat, x[:, KeyRange.Begin:KeyRange.End])
            # Core gradient-descent step; only the owned range is updated.
            w[KeyRange.Begin:KeyRange.End] -= eta * g

            Values = []
            for i in xrange(KeyRange.Begin, KeyRange.End):
                value = Value()
                value.Values.append(w[i])
                Values.append(value)
            use_time.append(time.time() - t1)

            lr.psClient.UpdateLocalRangedParameter(Values)
            lr.psClient.Push()
            logger.debug("update paramter {} times, mean use time {}".format(
                len(use_time), np.mean(use_time)))
    logger.info(
        "train lr with gd finished, use {} seconds".format(time.time() -
                                                           begin))
    return lr
def _ftrl_minibatches(samples, batch):
    """Yield ``(xBatch, yBatch)`` lists of ``batch`` samples, then the rest.

    A batch is emitted as soon as it holds at least ``batch`` samples; a
    final, shorter batch is emitted when samples are left over.
    """
    xs, ys = [], []
    for X, Y in samples:
        xs.append(X)
        ys.append(Y)
        if len(xs) >= batch:
            yield xs, ys
            xs, ys = [], []
    if xs:
        yield xs, ys


def TrainLrWithFTRL(corpusFile, splitNum, splitIndex, epoch, batch, manager,
                    port, ParameterTotal, KeyRange, alpha, beta, l1, l2,
                    synRound):
    """Train a logistic regression with FTRL on top of the parameter server.

    Two values are kept per key: ``Values[0]`` is ``z`` and ``Values[1]``
    is ``n``. For every mini-batch they are pulled, the weights are derived
    from them, the gradient for this worker's key range is computed, and
    the updated ``z``/``n`` of that range are pushed back.

    Args:
        corpusFile, splitNum, splitIndex: Passed to ``CorpusGenerator``
            together with ``ParameterTotal`` to produce ``(X, Y)`` samples.
        epoch: Number of passes over the corpus.
        batch: Number of samples per mini-batch.
        manager, port: Passed to ``LR`` for the parameter-server client.
        ParameterTotal: Total number of parameter keys.
        KeyRange: Object with ``Begin``/``End``; only the ``[Begin, End)``
            slice of ``z`` and ``n`` is updated and pushed by this worker.
        alpha, beta, l1, l2: Hyper-parameters of the weight and ``sigma``
            formulas below.
        synRound: Every ``synRound``-th pull has its message id remembered;
            later updates wait for that pull before issuing a new one.

    Returns:
        The ``LR`` instance used for training.
    """
    begin = time.time()
    # z and n start at 0 in FTRL; note that n must never be negative.
    lr = LR(manager, port, ParameterTotal, KeyRange, init_zero, 2)
    step = 1
    WaitedPullMsgId = 0
    use_time = []
    for ep in xrange(epoch):
        logger.info("epoch={}".format(ep))
        corpusGenerator = CorpusGenerator(corpusFile, splitNum, splitIndex,
                                          ParameterTotal)
        for xBatch, yBatch in _ftrl_minibatches(corpusGenerator, batch):
            if WaitedPullMsgId > 0:
                # Wait for the earlier pull to complete before continuing.
                if not lr.psClient.WaitPull(WaitedPullMsgId, 1):
                    logger.error("wait pull timeout")
            msgId = lr.psClient.Pull()
            if step % synRound == 0:
                WaitedPullMsgId = msgId
            step += 1

            t1 = time.time()
            x = np.array(xBatch)
            z = []
            n = []
            for v in lr.psClient.GetAllParameter():
                if len(v.Values) >= 2:
                    z.append(v.Values[0])
                    n.append(v.Values[1])
                else:
                    # NOTE(review): a key with fewer than two values is
                    # skipped, which shifts the indices of all later keys -
                    # confirm this cannot happen after initialization.
                    logger.error(
                        "parameters of one key less than 2: {}".format(
                            len(v.Values)))
            z = np.array(z)
            n = np.array(n)

            # Core FTRL formula: the weight is 0 where |z| <= l1, otherwise
            # (sign(z) * l1 - z) / (l2 + (beta + sqrt(n)) / alpha).
            denom = l2 + (beta + np.sqrt(n)) / alpha
            w = np.where(np.abs(z) <= l1, 0.0, (np.sign(z) * l1 - z) / denom)

            y_hat = lr.fn(w, x)
            y = np.array(yBatch).reshape(len(yBatch), 1)
            # Only the gradient for this worker's key range is needed.
            g = lr.grad(y, y_hat, x[:, KeyRange.Begin:KeyRange.End])

            n_owned = n[KeyRange.Begin:KeyRange.End]
            sigma = (np.sqrt(n_owned + g * g) - np.sqrt(n_owned)) / alpha
            # Only the range owned by this worker is updated.
            z[KeyRange.Begin:KeyRange.End] += (
                g - sigma * w[KeyRange.Begin:KeyRange.End])
            n[KeyRange.Begin:KeyRange.End] += g * g

            Values = []
            for i in xrange(KeyRange.Begin, KeyRange.End):
                value = Value()
                value.Values.extend([z[i], n[i]])
                Values.append(value)
            use_time.append(time.time() - t1)

            lr.psClient.UpdateLocalRangedParameter(Values)
            lr.psClient.Push()
            logger.debug("update paramter {} times, mean use time {}".format(
                len(use_time), np.mean(use_time)))
    logger.info(
        "train lr with ftrl finished, use {} seconds".format(time.time() -
                                                             begin))
    return lr
if WaitedPushMsgId > 0: client.WaitPush(WaitedPushMsgId, 0) # 等待上上次的Pull完成 if WaitedPullMsgId > 0: client.WaitPull(WaitedPullMsgId, 0) # Pull msgId = client.Pull() if i % 3 == 0: WaitedPullMsgId = msgId # Push values = [] for key in xrange(SelfRange.Begin, SelfRange.End): ele = Value() for j in xrange(ParameterCountOfKey): ele.Values.append(random.random()) values.append(ele) client.UpdateLocalRangedParameter(values) msgId = client.Push() if i % 3 == 0: WaitedPushMsgId = msgId time.sleep(0.1) params = client.GetAllParameter() for i in xrange(SelfRange.Begin, SelfRange.End): l = params[i] print i, "param", l.Values if i >= SelfRange.Begin + 10: