Exemplo n.º 1
0
 def update_parameters_opt(self):
     """Bring the local model up to date from the cheapest available source.

     Sources are checked in this order:

     1. Local hit -- ``self.ps.getGlobalStatus()`` equals
        ``self.status['GlobalStep']``: nothing is loaded and
        ``self.status['LocalHit']`` is incremented.
     2. Remote hit -- ``self.service_handler.getStatus()`` is non-negative,
        greater than ``self.status['GlobalStep']`` and fewer than
        ``self.staleness_threshold`` steps behind the ps global step: the
        handler's model is loaded and ``self.status['RemoteHit']`` is
        incremented.
     3. Otherwise the full model is downloaded from ``self.ps``. If the
        download or decoding fails, the ps connection is re-created from
        ``cluster_spec`` via ``init_conn``.

     If ``getGlobalStatus`` raises, the global step is taken to be -1.

     Returns:
         ``staleness`` (ps global step minus ``self.status['GlobalStep']``)
         after a successful download from ``self.ps``; ``None`` on a local
         hit, a remote hit, or a failed download.
     """
     # sync with ps
     try:
         gStatus = self.ps.getGlobalStatus()
     except Exception:
         # Best effort: an unreachable ps is recorded as global step -1.
         gStatus = -1
     local_step = self.status['GlobalStep']
     staleness = gStatus - local_step
     if gStatus == local_step:
         self.status['LocalHit'] += 1
         return
     # Query the handler exactly once so every check below sees the same
     # version and only a single call is made.
     remote_step = self.service_handler.getStatus()
     version_stamp_diff = gStatus - remote_step
     if (version_stamp_diff < self.staleness_threshold
             and remote_step >= 0 and remote_step > local_step):
         self.status['RemoteHit'] += 1
         model = comp.deprocess(self.service_handler.getModel(),
                                self.tensorgraph_shape)
         self.tensorgraph.put_parameters(model)
         return
     self.sw.reset()
     try:
         text = self.ps.download()
         self.sw.accumulate('download')
         model = comp.deprocess(text, self.tensorgraph_shape)
         self.sw.accumulate('deprocess')
     except Exception:
         # Replace the connection by assignment rather than `del self.ps`
         # first: if init_conn raises, the old attribute stays in place and
         # the next call can retry instead of failing on a missing `ps`.
         self.ps = init_conn(cluster_spec['ps'][0]['IP'],
                             cluster_spec['ps'][0]['Port'])
         return
     # Restart the stopwatch so 'put para' times only put_parameters.
     self.sw.reset()
     self.tensorgraph.put_parameters(model)
     self.sw.accumulate('put para')
     return staleness
Exemplo n.º 2
0
 def upload(self, cnid, u_grads):
     """Decode an uploaded gradient payload and apply it to ``self.cnn_graph``.

     Args:
         cnid: Not read by this method; presumably an identifier of the
             uploading client, kept for the caller-facing interface.
         u_grads: Encoded gradients, decoded with ``comp.deprocess`` against
             ``self.cnn_graph_shape``.

     Returns:
         The value of ``self.update_count`` right after this update was
         applied.
     """
     self.lock.acquire()
     try:
         grads = comp.deprocess(u_grads, self.cnn_graph_shape)
         self.cnn_graph.put_gradients(grads)
         self.update_count += 1
         # Snapshot while still holding the lock so a concurrent upload
         # cannot change the value reported back to this caller.
         count = self.update_count
     finally:
         # Always release, even if decoding or applying gradients raised;
         # otherwise every later upload would block on the lock forever.
         self.lock.release()
     return count
Exemplo n.º 3
0
 def update_parameters(self):
     """Fetch the model from ``self.ps`` and load it into ``self.tensorgraph``.

     The stopwatch ``self.sw`` is reset once, then tagged after each stage
     ('download', 'deprocess', 'put para'). No errors are caught here; any
     exception from the download, decoding or load reaches the caller.
     """
     stopwatch = self.sw
     stopwatch.reset()

     raw_payload = self.ps.download()
     stopwatch.accumulate('download')

     decoded = comp.deprocess(raw_payload, self.tensorgraph_shape)
     stopwatch.accumulate('deprocess')

     self.tensorgraph.put_parameters(decoded)
     stopwatch.accumulate('put para')
Exemplo n.º 4
0
 def update_parameters_ori(self):
     """Download the full model from ``self.ps`` and load it into the graph.

     The stopwatch ``self.sw`` is reset once and tagged after 'download'
     (which also covers the ``getGlobalStatus`` call), 'deprocess' and
     'put para'.

     Returns:
         ``staleness`` -- ``self.ps.getGlobalStatus()`` minus
         ``self.status['GlobalStep']`` -- on success. ``None`` if the status
         query, download or decoding raised; in that case the ps connection
         is re-created from ``cluster_spec`` via ``init_conn``.
     """
     self.sw.reset()
     try:
         gStatus = self.ps.getGlobalStatus()
         staleness = gStatus - self.status['GlobalStep']
         text = self.ps.download()
         self.sw.accumulate('download')
         model = comp.deprocess(text, self.tensorgraph_shape)
         self.sw.accumulate('deprocess')
     except Exception:
         # Replace the connection by assignment rather than `del self.ps`
         # first: if init_conn raises, the old attribute stays in place and
         # the next call can retry instead of failing on a missing `ps`.
         self.ps = init_conn(cluster_spec['ps'][0]['IP'],
                             cluster_spec['ps'][0]['Port'])
         return
     self.tensorgraph.put_parameters(model)
     self.sw.accumulate('put para')
     return staleness
Exemplo n.º 5
0
import thriftpy
import time
test_thrift = thriftpy.load("benchmarkmessage.thrift",
                            module_name="test_thrift")

from thriftpy.rpc import make_client
client = make_client(test_thrift.Helloworld, '127.0.0.1', 6000)

import compression as comp
import sys
sys.path.append('../../src')
from CIFAR10_CNN import CIFAR10_CNN as CNN

# Serialize the CNN parameters with comp.preprocess and time one echo()
# call carrying that payload to the thrift service at 127.0.0.1:6000.
# Single-argument print(...) produces the same output on Python 2 and 3.
cnn_graph = CNN()
text = comp.preprocess(cnn_graph.get_parameters())
print(len(text))

try:
    ts = time.time()
    ret = client.echo(text)
    te = time.time()
finally:
    # Release the transport opened by make_client, even if echo() fails.
    client.close()
print("networking elapsed time : %fs" % (te - ts))

# Decode the reply from echo(); presumably this checks that the payload
# round-trips, since the result is not otherwise used here.
model = comp.deprocess(ret, cnn_graph.get_configure())