def update_parameters_opt(self):
    """Sync the local model with the parameter server, using cached copies when fresh enough.

    Fast paths:
      1. Local hit  - global step equals our step: nothing to do.
      2. Remote hit - the background receiver already holds a model whose
         staleness is below ``self.staleness_threshold``: install it without
         downloading.
    Otherwise fall back to a full download from the PS.

    Returns:
        The staleness (global step minus our last seen step) after a full
        download, or None on a fast path / connection failure.
    """
    try:
        gStatus = self.ps.getGlobalStatus()
    except Exception:
        # PS unreachable; sentinel forces the slow path below.
        gStatus = -1
    staleness = gStatus - self.status['GlobalStep']
    if gStatus == self.status['GlobalStep']:
        # Already up to date: count the local cache hit and bail out.
        self.status['LocalHit'] += 1
        return
    # Hoist the handler status: the original queried it three times per call.
    handler_status = self.service_handler.getStatus()
    version_stamp_diff = gStatus - handler_status
    if (version_stamp_diff < self.staleness_threshold
            and handler_status >= 0
            and handler_status > self.status['GlobalStep']):
        # The received model is newer than ours and fresh enough - use it.
        self.status['RemoteHit'] += 1
        model = comp.deprocess(self.service_handler.getModel(),
                               self.tensorgraph_shape)
        self.tensorgraph.put_parameters(model)
        return
    # Slow path: download the full model from the parameter server.
    self.sw.reset()
    try:
        text = self.ps.download()
        self.sw.accumulate('download')
        model = comp.deprocess(text, self.tensorgraph_shape)
        self.sw.accumulate('deprocess')
    except Exception:
        # Connection presumed broken: rebuild it and give up this round.
        del self.ps
        self.ps = init_conn(cluster_spec['ps'][0]['IP'],
                            cluster_spec['ps'][0]['Port'])
        return
    self.sw.reset()
    self.tensorgraph.put_parameters(model)
    self.sw.accumulate('put para')
    return staleness
def kill_child_processes():
    """Report the PS update count, then send SIGINT to every child process.

    Connects to the first parameter server to print its final global update
    count, then walks this process's child tree (recursively) and interrupts
    each child so they can shut down cleanly.
    """
    ps = init_conn(cluster_spec['ps'][0]['IP'], cluster_spec['ps'][0]['Port'])
    # Parenthesized single-argument print: identical output on Python 2,
    # valid syntax on Python 3 (the original py2-only statement form is not).
    print("update count : %d" % ps.getGlobalStatus())
    parent_pid = os.getpid()
    try:
        parent = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        return
    children = parent.children(recursive=True)
    for process in children:
        process.send_signal(signal.SIGINT)
def __init__(self, cn_id, start, length):
    """Set up a computing node: data shard, model graph, and PS connection.

    Args:
        cn_id:  index of this computing node in cluster_spec['cn'].
        start:  offset of this node's slice of the training set.
        length: number of training examples in the slice.
    """
    self.id = cn_id
    self.batch_size = 200
    self.num_epochs = 10
    # Load only this node's shard of the dataset.
    (self.train_dataset, self.train_labels,
     self.valid_dataset, self.valid_labels,
     self.test_dataset, self.test_labels) = open_cifar10_dataset(start, length)
    device_conf = gpu_split(len(cluster_spec['cn']))
    self.tensorgraph = CNN(device_conf)
    self.tensorgraph_shape = self.tensorgraph.get_configure()
    # Establish connection with the parameter server's store service.
    ps_entry = cluster_spec['ps'][0]
    self.ps = init_conn(ps_entry['IP'], ps_entry['Port'])
    self.sw = StopWatch()
    self.logging = True
def __init__(self, cn_id, start, length, receive_service=True,
             uploading_in_background=True):
    """Set up a computing node with optional receiver/uploader services.

    Args:
        cn_id:  index of this computing node in cluster_spec['cn'].
        start:  offset of this node's slice of the training set.
        length: number of training examples in the slice.
        receive_service: when True, run a daemon thread that accepts pushed
            models and use the optimized parameter-update path.
        uploading_in_background: when True, use the optimized (background)
            upload path instead of the blocking one.
    """
    self.id = cn_id
    self.batch_size = 200
    self.num_epochs = 3
    self.staleness_threshold = 3
    # Load only this node's shard of the dataset.
    (self.train_dataset, self.train_labels,
     self.valid_dataset, self.valid_labels,
     self.test_dataset, self.test_labels) = open_cifar10_dataset(start, length)
    self.tensorgraph = CNN(gpu_split(len(cluster_spec['cn'])))
    self.tensorgraph_shape = self.tensorgraph.get_configure()
    # Establish connection with the parameter server's store service.
    ps_entry = cluster_spec['ps'][0]
    self.ps = init_conn(ps_entry['IP'], ps_entry['Port'])
    if receive_service:
        # Start a daemon thread so peers can push models to this node.
        self.service_handler = Handler()
        endpoint = cluster_spec['cn'][cn_id]
        receiver = threading.Thread(
            target=receive,
            args=(endpoint['IP'], endpoint['Port'], self.service_handler))
        receiver.daemon = True
        receiver.start()
        self.update_parameters = self.update_parameters_opt
    else:
        self.update_parameters = self.update_parameters_ori
    # Switch between original and optimized modes for uploading parameters.
    self.lock = threading.Lock()
    if uploading_in_background:
        self.upload_parameters = self.upload_parameters_opt
    else:
        self.upload_parameters = self.upload_parameters_ori
    self.sw = StopWatch()
    self.logging = True
    self.status = {
        'GlobalStep': -1,
        'LocalStep': 0,
        'LocalHit': 0,
        'RemoteHit': 0,
    }
def update_parameters_ori(self):
    """Baseline sync: always download the full model from the parameter server.

    Returns:
        The staleness (PS global step minus our last seen step), or None if
        the PS connection failed (in which case it is re-established).
    """
    self.sw.reset()
    try:
        gStatus = self.ps.getGlobalStatus()
        staleness = gStatus - self.status['GlobalStep']
        text = self.ps.download()
        self.sw.accumulate('download')
        model = comp.deprocess(text, self.tensorgraph_shape)
        self.sw.accumulate('deprocess')
    except Exception:
        # Narrowed from a bare except (which also caught KeyboardInterrupt /
        # SystemExit). Connection presumed broken: rebuild it and skip this
        # round; the caller treats None as "no update applied".
        del self.ps
        self.ps = init_conn(cluster_spec['ps'][0]['IP'],
                            cluster_spec['ps'][0]['Port'])
        return
    self.tensorgraph.put_parameters(model)
    self.sw.accumulate('put para')
    return staleness
# create parameter servers
if args.server:
    for i in range(ps_num):
        process = multiprocessing.Process(target=ps_job,
                                          args=(i, args.predict))
        process.start()
        ps_processes.append(process)
# create computing nodes, each owning an equal contiguous slice of the data
training_set_size = 30000
# Floor division: same result as py2's "/" on ints, and still an int on py3.
length = training_set_size // cn_num
for i in range(cn_num):
    process = multiprocessing.Process(
        target=cn_job,
        args=(i, i * length, length, args.predict, args.background))
    process.start()
    cn_processes.append(process)
signal.signal(signal.SIGINT, sig_handler)
# wait for training to finish on every computing node
for i in range(cn_num):
    cn_processes[i].join()
ps = init_conn(cluster_spec['ps'][0]['IP'], cluster_spec['ps'][0]['Port'])
# Parenthesized print: identical output on py2, valid syntax on py3.
print("update count : %d" % ps.getGlobalStatus())
if args.predict:
    ps.getUploadRecord()
kill_child_processes()
from thrift_conn import init_conn
import time

# Smoke test: push ten dummy payloads ("m0".."m9") to a local parameter
# server at one-second intervals, echoing the loop index as it goes.
conn = init_conn("127.0.0.1", 50001)
for i in range(10):
    # Parenthesized print: identical output on py2, valid syntax on py3.
    print(i)
    conn.upload("m" + str(i))
    time.sleep(1)