Exemplo n.º 1
0
    def process_request(self, worker_id, message):
        """Handle a training-related request from a worker.

        Overrides ``PTServer.process_request``. The base class gets the first
        chance to answer; only when it returns ``None`` is the message handled
        here as a training request.

        Args:
            worker_id: identifier of the requesting worker. Its ``'%s'``
                string form is the key into ``self.valid``, ``self.uidx``
                and ``self.adj_lr``.
            message: one of
                * ``'next'``: ask what to do next. The reply is ``'stop'``,
                  ``'val'``, ``'adjust_lr'`` or ``'train'``.
                * a mapping with a ``'done'`` entry: adds
                  ``message['done']`` to this worker's update counter. The
                  reply stays ``None``.
                * ``'uepoch'``: the reply is
                  ``[self.uepoch, len(self.worker_comm)]``.

        Returns:
            The base-class reply when it is not ``None``. Otherwise the reply
            described above, or ``None`` when no branch sets one.
        """
        reply = PTServer.process_request(self, worker_id, message)
        if reply is not None:
            return reply

        key = '%s' % worker_id

        # Register the worker on its first request. The lookups only probe
        # whether all three per-worker entries already exist.
        try:
            self.valid[key]
            self.uidx[key]
            self.adj_lr[key]
        except KeyError:
            self.valid[key] = False
            self.adj_lr[key] = False
            self.uidx[key] = 0
            # A new worker joined: every worker must re-adjust its learning rate.
            self.adj_lr = self.adj_lr.fromkeys(self.adj_lr, True)

        if message == 'next':
            if self.start_time is None:
                self.start_time = time.time()

            if sum(self.uidx.values()) >= self.max_mb:  # all epochs finished
                print("[Server] Total training time %.2fh" %
                      ((time.time() - self.start_time) / 3600.0))
                reply = 'stop'
            elif self.valid[key]:
                self.valid[key] = False
                reply = 'val'
            elif self.adj_lr[key]:
                self.adj_lr[key] = False
                reply = 'adjust_lr'
            else:
                reply = 'train'

        elif 'done' in message:
            self.uidx[key] += message['done']

        elif message == 'uepoch':
            reply = [self.uepoch, len(self.worker_comm)]

        if message in ['next', 'uepoch'] or 'done' in message:
            now_uidx = sum(self.uidx.values())
            self.uepoch = int(now_uidx / self.validFreq)

            if self.last_uepoch != self.uepoch:
                self.last_uepoch = self.uepoch
                # A new global epoch started. Every worker re-adjusts its
                # learning rate, and only the first worker validates.
                self.adj_lr = self.adj_lr.fromkeys(self.adj_lr, True)
                self.valid['%s' % self.first_worker_id] = True

                # Server alpha drops by 'alpha_minus' for each 'alpha_step'
                # threshold the epoch has strictly passed (thresholds are
                # presumably ascending). Counting passed thresholds avoids the
                # gap the old chained range tests had: an epoch exactly equal
                # to the 2nd or 3rd threshold matched no range and reset alpha.
                step_idx = sum(1 for step in self.config['alpha_step']
                               if self.uepoch > step)
                self.exchanger.alpha = (self.config['server_alpha'] -
                                        self.config['alpha_minus'] * step_idx)
                print('server alpha changed to %f' % self.exchanger.alpha)

            if self.last is None:
                self.last = float(time.time())

            # Log throughput once at least 400 updates have accumulated since
            # the last report. The *40.0 factor turns seconds-per-update into
            # the logged "per 5120 images" figure. This presumably assumes
            # 128 images per update; verify against the worker batch size.
            if now_uidx - self.last_uidx >= 400:
                now = float(time.time())
                print('[Server] %d time per 5120 images: %.2f s' %
                      (self.uepoch,
                       (now - self.last) / (now_uidx - self.last_uidx) * 40.0))
                self.last_uidx = now_uidx
                self.last = now

        return reply
Exemplo n.º 2
0
    def process_request(self, worker_id, message):
        """Handle a training-related request from a worker.

        Overrides ``PTServer.process_request``. The base class gets the first
        chance to answer; only when it returns ``None`` is the message handled
        here as a training request.

        Args:
            worker_id: identifier of the requesting worker. Its ``'%s'``
                string form is the key into ``self.valid``, ``self.uidx``
                and ``self.adj_lr``.
            message: one of
                * ``'next'``: ask what to do next. The reply is ``'stop'``,
                  ``'val'``, ``'adjust_lr'`` or ``'train'``.
                * a mapping with a ``'done'`` entry: adds
                  ``message['done']`` to this worker's update counter. The
                  reply stays ``None``.
                * ``'uepoch'``: the reply is
                  ``[self.uepoch, len(self.worker_comm)]``.

        Returns:
            The base-class reply when it is not ``None``. Otherwise the reply
            described above, or ``None`` when no branch sets one.
        """
        reply = PTServer.process_request(self, worker_id, message)
        if reply is not None:
            return reply

        key = '%s' % worker_id

        # Register the worker on its first request. The lookups only probe
        # whether all three per-worker entries already exist.
        try:
            self.valid[key]
            self.uidx[key]
            self.adj_lr[key]
        except KeyError:
            self.valid[key] = False
            self.adj_lr[key] = False
            self.uidx[key] = 0
            # A new worker joined: every worker must re-adjust its learning rate.
            self.adj_lr = self.adj_lr.fromkeys(self.adj_lr, True)

        if message == 'next':
            if self.start_time is None:
                self.start_time = time.time()

            if sum(self.uidx.values()) >= self.max_mb:  # all epochs finished
                print("[Server] Total training time %.2fh" %
                      ((time.time() - self.start_time) / 3600.0))
                reply = 'stop'
            elif self.valid[key]:
                self.valid[key] = False
                reply = 'val'
            elif self.adj_lr[key]:
                self.adj_lr[key] = False
                reply = 'adjust_lr'
            else:
                reply = 'train'

        elif 'done' in message:
            self.uidx[key] += message['done']

        elif message == 'uepoch':
            reply = [self.uepoch, len(self.worker_comm)]

        if message in ['next', 'uepoch'] or 'done' in message:
            now_uidx = sum(self.uidx.values())
            self.uepoch = int(now_uidx / self.validFreq)

            if self.last_uepoch != self.uepoch:
                self.last_uepoch = self.uepoch
                # A new global epoch started. Every worker re-adjusts its
                # learning rate, and only the first worker validates.
                self.adj_lr = self.adj_lr.fromkeys(self.adj_lr, True)
                self.valid['%s' % self.first_worker_id] = True

                # Server alpha drops by 'alpha_minus' for each 'alpha_step'
                # threshold the epoch has strictly passed (thresholds are
                # presumably ascending). Counting passed thresholds avoids the
                # gap the old chained range tests had: an epoch exactly equal
                # to the 2nd or 3rd threshold matched no range and reset alpha.
                step_idx = sum(1 for step in self.config['alpha_step']
                               if self.uepoch > step)
                self.exchanger.alpha = (self.config['server_alpha'] -
                                        self.config['alpha_minus'] * step_idx)
                print('server alpha changed to %f' % self.exchanger.alpha)

            if self.last is None:
                self.last = float(time.time())

            # Log throughput once at least 400 updates have accumulated since
            # the last report. The *40.0 factor turns seconds-per-update into
            # the logged "per 5120 images" figure. This presumably assumes
            # 128 images per update; verify against the worker batch size.
            if now_uidx - self.last_uidx >= 400:
                now = float(time.time())
                print('[Server] %d time per 5120 images: %.2f s' %
                      (self.uepoch,
                       (now - self.last) / (now_uidx - self.last_uidx) * 40.0))
                self.last_uidx = now_uidx
                self.last = now

        return reply