def do_POST(self):
    """
    Serve a POST request, which normally carries a Thrift JSON message.

    The request body is read once to learn the called Thrift function
    name (for the access log), then replayed from memory to the API
    processor selected by the request path. Every failure is reported
    to the client through ``send_thrift_exception``.
    """
    # Input and output share one JSON protocol factory.
    factory = TJSONProtocol.TJSONProtocolFactory()

    # Read the message header to get the Thrift function name for the log.
    buffered_in = TTransport.TBufferedTransport(
        TTransport.TFileObjectTransport(self.rfile),
        int(self.headers['Content-Length']))
    fname, _, _ = factory.getProtocol(buffered_in).readMessageBegin()

    client_host, client_port, is_ipv6 = \
        RequestHandler._get_client_host_port(self.client_address)
    # IPv6 literals are logged in bracketed form.
    host_label = client_host if not is_ipv6 else '[' + client_host + ']'

    self.auth_session = self.__check_session_cookie()
    LOG.info("%s:%s -- [%s] POST %s@%s",
             host_label,
             client_port,
             self.auth_session.user if self.auth_session else "Anonymous",
             self.path,
             fname)

    # Values handed to the API handlers created below.
    checker_md_docs = self.server.checker_md_docs
    checker_md_docs_map = self.server.checker_md_docs_map
    version = self.server.version

    # Replay the already buffered request body from memory so that the
    # processor sees the complete message again.
    itrans = TTransport.TMemoryBuffer(buffered_in.cstringio_buf.getvalue())
    iprot = factory.getProtocol(itrans)

    otrans = TTransport.TMemoryBuffer()
    oprot = factory.getProtocol(otrans)

    open_endpoints = ('/Authentication', '/Configuration', '/ServerInfo')
    if self.server.manager.is_enabled and \
            not self.path.endswith(open_endpoints) and \
            not self.auth_session:
        # Bail out if the user is not authenticated...
        # This response has the possibility of melting down Thrift clients,
        # but the user is expected to properly authenticate first.
        LOG.debug("%s:%s Invalid access, credentials not found "
                  "- session refused.",
                  host_label, str(client_port))
        self.send_thrift_exception("Error code 401: Unauthorized!",
                                   iprot, oprot, otrans)
        return

    # Authentication is handled, we may now respond to the user.
    try:
        product_endpoint, api_ver, request_endpoint = \
            routing.split_client_POST_request(self.path)
        if product_endpoint is None and api_ver is None and \
                request_endpoint is None:
            raise Exception("Invalid request endpoint path.")

        product = None
        if product_endpoint:
            # The current request came through a product route, and not
            # to the main endpoint.
            product = self.__check_prod_db(product_endpoint)

        version_supported = routing.is_supported_version(api_ver)
        if not version_supported:
            error_msg = "The API version you are using is not supported " \
                        "by this server (server API version: {0})!".format(
                            get_version_str())
            self.send_thrift_exception(error_msg, iprot, oprot, otrans)
            return

        major_version, _ = version_supported
        if major_version == 6:
            if request_endpoint == 'Authentication':
                handler = AuthHandler_v6(
                    self.server.manager,
                    self.auth_session,
                    self.server.config_session)
                processor = AuthAPI_v6.Processor(handler)
            elif request_endpoint == 'Configuration':
                handler = ConfigHandler_v6(
                    self.auth_session,
                    self.server.config_session)
                processor = ConfigAPI_v6.Processor(handler)
            elif request_endpoint == 'ServerInfo':
                handler = ServerInfoHandler_v6(version)
                processor = ServerInfoAPI_v6.Processor(handler)
            elif request_endpoint == 'Products':
                handler = ProductHandler_v6(
                    self.server,
                    self.auth_session,
                    self.server.config_session,
                    product,
                    version)
                processor = ProductAPI_v6.Processor(handler)
            elif request_endpoint == 'CodeCheckerService':
                # This endpoint is a product's report_server.
                if not product:
                    error_msg = "Requested CodeCheckerService on a " \
                                "nonexistent product: '{0}'." \
                        .format(product_endpoint)
                    LOG.error(error_msg)
                    raise ValueError(error_msg)

                if product_endpoint:
                    # The current request came through a
                    # product route, and not to the main endpoint.
                    product = self.__check_prod_db(product_endpoint)

                handler = ReportHandler_v6(
                    self.server.manager,
                    product.session_factory,
                    product,
                    self.auth_session,
                    self.server.config_session,
                    checker_md_docs,
                    checker_md_docs_map,
                    version,
                    self.server.context)
                processor = ReportAPI_v6.Processor(handler)
            else:
                LOG.debug("This API endpoint does not exist.")
                error_msg = "No API endpoint named '{0}'." \
                    .format(self.path)
                raise ValueError(error_msg)

        # NOTE: for a supported version whose major number is not 6 no
        # processor has been selected; the resulting error is reported by
        # the generic handler below, as before.
        processor.process(iprot, oprot)
        result = otrans.getvalue()

        self.send_response(200)
        self.send_header("content-type", "application/x-thrift")
        self.send_header("Content-Length", len(result))
        self.end_headers()
        self.wfile.write(result)
        return

    except Exception as exn:
        LOG.warning(str(exn))
        import traceback
        traceback.print_exc()

        # Rebuild the input protocol from the buffered request (when there
        # is one) before the error reply is written.
        request_bytes = itrans.cstringio_buf.getvalue()
        if request_bytes:
            itrans = TTransport.TMemoryBuffer(request_bytes)
            iprot = factory.getProtocol(itrans)

        self.send_thrift_exception(str(exn), iprot, oprot, otrans)
        return
def pass_through_dfe(size, data_in):
    """PassThrough DFE implementation.

    Connects to the Thrift server on localhost:9090, uploads ``data_in``,
    runs the ``PassThrough`` action and downloads the output stream. The
    duration of every step is printed.

    :param size: number of float elements in the stream
    :param data_in: input values sent to the server
    :return: the output stream received from the server
    :raises SystemExit: with status -1 when a Thrift error occurs
    """
    try:
        start_time = time.time()

        # Make socket
        sock = TSocket.TSocket('localhost', 9090)

        # Buffering is critical. Raw sockets are very slow
        transport = TTransport.TBufferedTransport(sock)

        # Wrap in a protocol
        protocol = TBinaryProtocol.TBinaryProtocol(transport)

        # Create a client to use the protocol encoder
        client = PassThroughService.Client(protocol)
        print('Creating a client:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Connect!
        start_time = time.time()
        transport.open()
        print('Opening connection:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Allocate and send input streams to server
        start_time = time.time()
        address_data_in = client.malloc_float(size)
        client.send_data_float(address_data_in, data_in)
        print('Sending input data:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Allocate memory for output stream on server
        start_time = time.time()
        address_data_out = client.malloc_float(size)
        print('Allocating memory for output stream on server:\t%.5lfs' %
              (time.time() - start_time))

        # Action default
        start_time = time.time()
        client.PassThrough(size, address_data_in, address_data_out)
        print('Pass through time:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Get output stream from server
        start_time = time.time()
        data_out = client.receive_data_float(address_data_out, size)
        print('Getting output stream:\t(size = %d bit)\t%.5lfs' %
              ((size * 32), (time.time() - start_time)))

        # Free allocated memory for streams on server
        start_time = time.time()
        client.free(address_data_in)
        client.free(address_data_out)
        print('Freeing allocated memory for streams on server:\t%.5lfs' %
              (time.time() - start_time))

        # Close!
        start_time = time.time()
        transport.close()
        print('Closing connection:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

    # "except ... as" and print(...) with one argument are valid on both
    # Python 2.6+ and Python 3.
    except Thrift.TException as thrift_exception:
        print('%s' % (thrift_exception.message))
        sys.exit(-1)

    # The downloaded stream used to be discarded; hand it to the caller.
    return data_out
def main(cfg, reqhandle, resphandle):
    """Run one example RPC and dump the recorded request/response bytes.

    The socket, transport and protocol are chosen from ``cfg`` (``unix``,
    ``addr``, ``transport``, ``headers``, ``protocol``, ``service``), then
    the method named by ``cfg.method`` is called with ``cfg.params``.
    Invalid configuration terminates the process through ``sys.exit``.

    :param cfg: parsed command line configuration
    :param reqhandle: buffer that receives the raw request bytes
    :param resphandle: buffer that receives the raw response bytes
    """
    if cfg.unix:
        if cfg.addr == "":
            sys.exit("invalid unix domain socket: {}".format(cfg.addr))
        sock = TSocket.TSocket(unix_socket=cfg.addr)
    else:
        try:
            (host, port) = cfg.addr.rsplit(":", 1)
            if host == "":
                host = "localhost"
            sock = TSocket.TSocket(host=host, port=int(port))
        except ValueError:
            sys.exit("invalid address: {}".format(cfg.addr))

    transport = TRecordingTransport(sock, reqhandle, resphandle)

    if cfg.transport == "framed":
        transport = TTransport.TFramedTransport(transport)
    elif cfg.transport == "unframed":
        transport = TTransport.TBufferedTransport(transport)
    elif cfg.transport == "header":
        transport = THeaderTransport.THeaderTransport(
            transport,
            client_type=THeaderTransport.CLIENT_TYPE.HEADER,
        )

        if cfg.headers is not None:
            for pair in cfg.headers.split(","):
                # Split on the first "=" only, so values that contain "="
                # themselves (e.g. base64 padding) are kept intact.
                key, sep, value = pair.partition("=")
                if not sep:
                    sys.exit("invalid header: {0}".format(pair))
                transport.set_header(key, value)

        if cfg.protocol == "binary":
            transport.set_protocol_id(THeaderTransport.T_BINARY_PROTOCOL)
        elif cfg.protocol == "compact":
            transport.set_protocol_id(THeaderTransport.T_COMPACT_PROTOCOL)
        else:
            sys.exit("header transport cannot be used with protocol {0}".format(cfg.protocol))
    else:
        sys.exit("unknown transport {0}".format(cfg.transport))

    transport.open()

    if cfg.protocol == "binary":
        protocol = TBinaryProtocol.TBinaryProtocol(transport)
    elif cfg.protocol == "compact":
        protocol = TCompactProtocol.TCompactProtocol(transport)
    elif cfg.protocol == "json":
        protocol = TJSONProtocol.TJSONProtocol(transport)
    elif cfg.protocol == "finagle":
        protocol = TFinagleProtocol(transport, client_id="thrift-playground")
    else:
        sys.exit("unknown protocol {0}".format(cfg.protocol))

    if cfg.service is not None:
        protocol = TMultiplexedProtocol.TMultiplexedProtocol(protocol, cfg.service)

    client = Example.Client(protocol)

    try:
        if cfg.method == "ping":
            client.ping()
            print("client: pinged")
        elif cfg.method == "poke":
            client.poke()
            print("client: poked")
        elif cfg.method == "add":
            if len(cfg.params) != 2:
                sys.exit("add takes 2 arguments, got: {0}".format(cfg.params))

            a = int(cfg.params[0])
            b = int(cfg.params[1])
            v = client.add(a, b)
            print("client: added {0} + {1} = {2}".format(a, b, v))
        elif cfg.method == "execute":
            param = Param(
                return_fields=cfg.params,
                the_works=TheWorks(
                    field_1=True,
                    field_2=0x7f,
                    field_3=0x7fff,
                    field_4=0x7fffffff,
                    field_5=0x7fffffffffffffff,
                    field_6=-1.5,
                    field_7=u"string is UTF-8: \U0001f60e",
                    field_8=b"binary is bytes: \x80\x7f\x00\x01",
                    field_9={
                        1: "one",
                        2: "two",
                        3: "three"
                    },
                    field_10=[1, 2, 4, 8],
                    field_11=set(["a", "b", "c"]),
                    field_12=False,
                ))

            try:
                result = client.execute(param)
                print("client: executed {0}: {1}".format(param, result))
            except AppException as e:
                print("client: execute failed with IDL Exception: {0}".format(e.why))
        else:
            sys.exit("unknown method {0}".format(cfg.method))
    except Thrift.TApplicationException as e:
        print("client exception: {0}: {1}".format(e.type, e.message))

    # Hex-dump the recorded traffic unless it was written elsewhere.
    # NOTE(review): ord(x) presumes getvalue() yields 1-character strings
    # (Python 2 str); iterating Python 3 bytes yields ints - confirm the
    # type of the handles passed in by the caller.
    if cfg.request is None:
        req = "".join(["%02X " % ord(x) for x in reqhandle.getvalue()]).strip()
        print("request: {}".format(req))
    if cfg.response is None:
        resp = "".join(["%02X " % ord(x) for x in resphandle.getvalue()]).strip()
        print("response: {}".format(resp))

    transport.close()
def _shorten_path(path, limit=50):
    """Return ``path`` unchanged if short, otherwise an abbreviated form."""
    if len(path) <= limit:
        return path
    # Prefer "...<sep>" plus the last three components. This needs at
    # least three separators; backslash is tried first.
    for sep in ('\\', '/'):
        if path.count(sep) >= 3:
            return '...' + sep + sep.join(path.rsplit(sep, 3)[1:])
    # Too few separators: keep the tail of the string.
    return '...' + path[-limit:]


def getresultbyid(resid):
    """Fetch a stored comparison result and return it as table rows.

    Uses the module-level ``host`` and ``port`` to reach the server.

    :param resid: identifier of the result on the server
    :return: list of ``[queryset, query, targetset, target, score]``
        rows, one per ``query|target|score`` line of the result; lines
        with a different number of ``|`` separated fields are skipped
    """
    socket = TSocket.TSocket(host, port)
    transport = TTransport.TBufferedTransport(socket)
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    client = sdhashsrv.Client(protocol)
    transport.open()
    try:
        result = client.displayResult(resid)
        info = client.displayResultInfo(resid)
    finally:
        # Release the connection even when an RPC fails.
        transport.close()

    # The info text starts with the query set name; when it contains
    # exactly one "--" the third word is taken as the target set name.
    # Otherwise the result is a set compared with itself.
    queryset = 'query'
    targetset = 'target'
    stuff = info.split()
    if stuff:
        queryset = stuff[0]
        if info.count('--') == 1 and len(stuff) > 2:
            targetset = stuff[2]
        else:
            targetset = queryset

    output = []
    for line in result.split('\n'):
        cols = line.split('|')
        if len(cols) != 3:
            continue
        output.append([
            queryset,
            _shorten_path(cols[0]),
            targetset,
            _shorten_path(cols[1]),
            cols[2],
        ])
    return output
from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol
from h2oai_scoring import ScoringService
from h2oai_scoring.ttypes import Row

# -----------------------------------------------
# Name           Type      Range
# -----------------------------------------------
# L3_S30_D3521   float64   [0.4, 1718.39] or None
# -----------------------------------------------

# Connect to the scoring server on localhost:9090.
socket = TSocket.TSocket('localhost', 9090)
transport = TTransport.TBufferedTransport(socket)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
client = ScoringService.Client(protocol)
transport.open()

server_hash = client.get_hash()
# The placeholder was missing, so the hash itself was never printed.
print('Scoring server hash: {}'.format(server_hash))

print('Scoring individual rows...')
row1 = Row()
row1.l3S30D3521 = 1660.7575097837375  # L3_S30_D3521
row2 = Row()
row2.l3S30D3521 = 757.5741980572641  # L3_S30_D3521
row3 = Row()
def handler(event, context):
    """Benchmark parameter-server communication with random gradients.

    For every batch of every epoch the worker pulls the model, pushes a
    random gradient of length ``size`` and waits until the next
    iteration can be pulled. Timings are printed per step.

    :param event: dict with the keys ``rank``, ``num_workers``, ``host``,
        ``port`` and ``size``
    :param context: not used
    """
    start_time = time.time()

    worker_index = event['rank']
    num_workers = event['num_workers']
    host = event['host']
    port = event['port']
    size = event['size']

    print('number of workers = {}'.format(num_workers))
    print('worker index = {}'.format(worker_index))
    print("host = {}".format(host))
    print("port = {}".format(port))
    print("size = {}".format(size))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(host, port)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()

    try:
        # test thrift connection
        ps_client.ping(t_client)
        print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
            host, port))

        # register model
        ps_client.register_model(t_client, worker_index, MODEL_NAME, size,
                                 num_workers)
        ps_client.exist_model(t_client, MODEL_NAME)
        print("register and check model >>> name = {}, length = {}".format(
            MODEL_NAME, size))

        # Training the Model
        train_start = time.time()
        iter_counter = 0
        for epoch in range(NUM_EPOCHS):
            epoch_start = time.time()
            for batch_index in range(NUM_BATCHES):
                print("------worker {} epoch {} batch {}------".format(
                    worker_index, epoch, batch_index))
                batch_start = time.time()

                # Nothing is trained here, so the loss is a placeholder.
                loss = 0.0

                # pull latest model (only the transfer is measured, the
                # returned model is not used)
                ps_client.can_pull(t_client, MODEL_NAME, iter_counter,
                                   worker_index)
                pull_start = time.time()
                ps_client.pull_model(t_client, MODEL_NAME, iter_counter,
                                     worker_index)
                pull_time = time.time() - pull_start

                w_b_grad = np.random.rand(1, size).astype(np.double).flatten()

                # push gradient to PS
                ps_client.can_push(t_client, MODEL_NAME, iter_counter,
                                   worker_index)
                push_start = time.time()
                ps_client.push_grad(t_client, MODEL_NAME, w_b_grad,
                                    LEARNING_RATE, iter_counter, worker_index)
                push_time = time.time() - push_start
                ps_client.can_pull(t_client, MODEL_NAME, iter_counter + 1,
                                   worker_index)  # sync all workers

                print(
                    'Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                    'batch cost %.4f s: pull model cost %.4f s, push update cost %.4f s'
                    % (epoch + 1, NUM_EPOCHS, batch_index, NUM_BATCHES,
                       time.time() - train_start, loss,
                       time.time() - epoch_start, time.time() - batch_start,
                       pull_time, push_time))
                iter_counter += 1
    finally:
        # Always release the connection to the parameter server.
        transport.close()

    end_time = time.time()
    print("Elapsed time = {} s".format(end_time - start_time))
def connect(self, host=None, port=None, uri=None, timeout=10000):
    """
    Open the connection to the server; call this before any operation.

    The target is taken from ``host``/``port`` when ``host`` is given,
    otherwise from ``uri``, otherwise from
    ``config.THRIFTCLIENT_TRANSPORT``. Only the ``tcp`` scheme is
    accepted and the port falls back to 9090.

    :type host: str
    :type port: str
    :type uri: str
    :type timeout: int

    :param host: (Optional) host of the server
    :param port: (Optional) port of the server, default port is 9090
    :param uri: (Optional) server address such as `tcp://127.0.0.1:9090`
    :param timeout: (Optional) value handed to the socket's ``setTimeout``

    :return: Status, indicate if connect is successful
    :rtype: Status

    :raises RepeatingConnectError: when already connected
    :raises RuntimeError: on an unsupported uri scheme or an unknown
        protocol setting
    :raises NotConnectError: when the transport cannot be opened
    """
    if self.status and self.status == Status.SUCCESS:
        raise RepeatingConnectError("You have already connected!")

    config_uri = urlparse(config.THRIFTCLIENT_TRANSPORT)
    _uri = urlparse(uri) if uri else config_uri

    # Resolve the endpoint: an explicit host wins over any uri.
    if host:
        port = port or 9090
    elif _uri.scheme == 'tcp':
        host = _uri.hostname
        port = _uri.port or 9090
    elif uri:
        raise RuntimeError('Invalid parameter uri: {}'.format(uri))
    else:
        raise RuntimeError(
            'Invalid configuration for THRIFTCLIENT_TRANSPORT: {transport}'
            .format(transport=config.THRIFTCLIENT_TRANSPORT))

    self._transport = TSocket.TSocket(host, port)
    if timeout:
        self._transport.setTimeout(int(timeout))

    # Stack the optional transport layers selected in the configuration.
    if config.THRIFTCLIENT_BUFFERED:
        self._transport = TTransport.TBufferedTransport(self._transport)
    if config.THRIFTCLIENT_ZLIB:
        self._transport = TZlibTransport.TZlibTransport(self._transport)
    if config.THRIFTCLIENT_FRAMED:
        self._transport = TTransport.TFramedTransport(self._transport)

    # Pick the protocol class first, instantiate it once afterwards.
    if config.THRIFTCLIENT_PROTOCOL == Protocol.BINARY:
        protocol_cls = TBinaryProtocol.TBinaryProtocol
    elif config.THRIFTCLIENT_PROTOCOL == Protocol.COMPACT:
        protocol_cls = TCompactProtocol.TCompactProtocol
    elif config.THRIFTCLIENT_PROTOCOL == Protocol.JSON:
        protocol_cls = TJSONProtocol.TJSONProtocol
    else:
        raise RuntimeError(
            "invalid configuration for THRIFTCLIENT_PROTOCOL: {protocol}".
            format(protocol=config.THRIFTCLIENT_PROTOCOL))
    protocol = protocol_cls(self._transport)

    self._client = MilvusService.Client(protocol)

    try:
        self._transport.open()
    except TTransport.TTransportException as e:
        self.status = Status(code=e.type, message=e.message)
        LOGGER.error(e)
        raise NotConnectError('Connection failed')

    self.status = Status(Status.SUCCESS, 'Connected')
    return self.status
def get_client(self, host, port):
    """Build a KNN service client for ``host:port``.

    The transport is returned unopened, together with the client.

    :return: ``(client, transport)`` tuple
    """
    buffered = TTransport.TBufferedTransport(TSocket.TSocket(host, port))
    service = KnnThriftService.Client(
        TBinaryProtocol.TBinaryProtocolAccelerated(buffered))
    return service, buffered
# Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/1.10/howto/deployment/checklist/ # SECURITY WARNING: keep the secret key used in production secret! SECRET_KEY = config.get('default','SECRET_KEY') # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True ALLOWED_HOSTS = [] TRANSPORT = TSocket.TSocket(config.get('thrift','ADDRESS'), config.get('thrift','PORT')) TRANSPORT = TTransport.TBufferedTransport(TRANSPORT) PROTOCOL = TBinaryProtocol.TBinaryProtocol(TRANSPORT) TRANSPORT.open() # Application definition INSTALLED_APPS = [ 'home.apps.HomeConfig', 'django.contrib.admin', 'django.contrib.auth', 'django.contrib.contenttypes', 'django.contrib.sessions', 'django.contrib.messages', 'django.contrib.staticfiles',
def main():
    """Command line client for the key-value store service.

    Options: ``-server <host:port>`` (default 127.0.0.1:9090),
    ``-set <key> <value>``, ``-get <key> <outputFile>`` and
    ``-del <key>``. A wrong number of option values prints the usage
    text to stderr and returns.
    """
    usage = ('Usage: KVStoreClient -server <host:port> '
             '-CMD_NAME <key w/ val w/ outputFile>')

    parser = argparse.ArgumentParser()
    parser.add_argument('-server', nargs='+')
    parser.add_argument('-set', nargs='+')
    parser.add_argument('-get', nargs='+')
    parser.add_argument('-del', nargs='+')
    args = parser.parse_args()

    # Make socket
    # argparse always defines the attribute (None when the option is not
    # given), so test the value rather than membership.
    server = getattr(args, 'server')
    if server is not None:
        host_port = server[0].split(':')
        if len(host_port) != 2:
            sys.stderr.write(usage)
            return
        transport = TSocket.TSocket(host_port[0], host_port[1])
    else:
        transport = TSocket.TSocket('127.0.0.1', 9090)

    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)

    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)

    # Create a client to use the protocol encoder
    client = Client(protocol)

    # Connect!
    transport.open()

    # The finally clause closes the connection on the early usage-error
    # returns and on failures as well.
    try:
        result = None

        key_value = getattr(args, 'set')
        if key_value is not None:
            if len(key_value) != 2:
                sys.stderr.write(usage)
                return
            result = set_kv(client, key_value[0], key_value[1])

        key_file = getattr(args, 'get')
        if key_file is not None:
            if len(key_file) != 2:
                sys.stderr.write(usage)
                return
            result = get_k(client, key_file[0])
            if result.error == 0:
                print(result.value)
                # "with" makes sure the output file is really closed.
                with open(key_file[1], 'w') as out_file:
                    out_file.write(result.value)

        # "del" is a keyword, hence getattr().
        key = getattr(args, 'del')
        if key is not None:
            if len(key) != 1:
                sys.stderr.write(usage)
                return
            result = del_k(client, key[0])

        result_printer(result)
    finally:
        # Close!
        transport.close()
def handler(event, context):
    """Train a sparse SVM, synchronising the model via a parameter server.

    The training file is read from S3. Worker index and worker count are
    parsed from the first two ``_`` separated parts of the object key.

    :param event: dict with the keys ``num_features``, ``learning_rate``,
        ``batch_size``, ``num_epochs``, ``validation_ratio``,
        ``bucket_name`` and ``key``
    :param context: not used
    """
    startTs = time.time()
    num_features = event['num_features']
    learning_rate = event["learning_rate"]
    batch_size = event["batch_size"]
    num_epochs = event["num_epochs"]
    validation_ratio = event["validation_ratio"]

    # Reading data from S3
    bucket_name = event['bucket_name']
    key = urllib.parse.unquote_plus(event['key'], encoding='utf-8')
    print(f"Reading training data from bucket = {bucket_name}, key = {key}")
    key_splits = key.split("_")
    worker_index = int(key_splits[0])
    num_worker = int(key_splits[1])

    # read file from s3
    file = get_object(bucket_name, key).read().decode('utf-8').split("\n")
    print("read data cost {} s".format(time.time() - startTs))

    parse_start = time.time()
    dataset = SparseDatasetWithLines(file, num_features)
    print("parse data cost {} s".format(time.time() - parse_start))

    preprocess_start = time.time()
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_ratio * dataset_size))
    # Fixed seed: the train/validation split is reproducible.
    np.random.seed(42)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_set = [dataset[i] for i in train_indices]
    val_set = [dataset[i] for i in val_indices]
    print("preprocess data cost {} s".format(time.time() - preprocess_start))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(constants.HOST, constants.PORT)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()

    try:
        # test thrift connection
        ps_client.ping(t_client)
        print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
            constants.HOST, constants.PORT))

        svm = SparseSVM(train_set, val_set, num_features, num_epochs,
                        learning_rate, batch_size)

        # register model
        model_name = "w.b"
        model_length = num_features
        ps_client.register_model(t_client, worker_index, model_name,
                                 model_length, num_worker)
        ps_client.exist_model(t_client, model_name)
        print("register and check model >>> name = {}, length = {}".format(
            model_name, model_length))

        # Training the Model
        train_start = time.time()
        iter_counter = 0
        for epoch in range(num_epochs):
            epoch_start = time.time()
            num_batches = math.floor(len(train_set) / batch_size)
            print(f"worker {worker_index} epoch {epoch}")

            for batch_idx in range(num_batches):
                batch_start = time.time()

                # pull latest model
                ps_client.can_pull(t_client, model_name, iter_counter,
                                   worker_index)
                latest_model = ps_client.pull_model(t_client, model_name,
                                                    iter_counter, worker_index)
                svm.weights = torch.from_numpy(latest_model).reshape(
                    num_features, 1)

                # The returned batch is not used here; the call is kept
                # because it may advance state inside the SVM object.
                svm.next_batch(batch_idx)
                acc = svm.one_epoch(batch_idx, epoch)
                compute_end = time.time()

                sync_start = time.time()
                w_update = svm.weights - latest_model
                ps_client.can_push(t_client, model_name, iter_counter,
                                   worker_index)
                ps_client.push_update(t_client, model_name, w_update,
                                      learning_rate, iter_counter, worker_index)
                ps_client.can_pull(t_client, model_name, iter_counter + 1,
                                   worker_index)  # sync all workers
                sync_time = time.time() - sync_start

                # Report against the num_epochs received in the event; the
                # loop above runs for exactly that many epochs.
                print(
                    'Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, train acc: %.4f, epoch cost %.4f, '
                    'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                    % (epoch + 1, num_epochs, batch_idx + 1,
                       len(train_indices) / batch_size,
                       time.time() - train_start, acc,
                       time.time() - epoch_start, time.time() - batch_start,
                       compute_end - batch_start, sync_time))
                iter_counter += 1

            val_acc = svm.evaluate()
            print("Epoch takes {}s, validation accuracy: {}".format(
                time.time() - epoch_start, val_acc))
    finally:
        # Always release the connection to the parameter server.
        transport.close()
#!/usr/bin/env python
import sys
sys.path.append("gen-py")

from thrift.transport import TTransport
from thrift.transport import TSocket
from thrift.protocol import TBinaryProtocol

from risk import Rating
from risk.ttypes import *

# Connect to the rating service on localhost:9090.
transport = TTransport.TBufferedTransport(TSocket.TSocket("localhost", "9090"))
protocol = TBinaryProtocol.TBinaryProtocol(transport)
client = Rating.Client(protocol)

transport.open()
try:
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(client.piotroski("JAVA"))
    print(client.altman_z("GOOG"))
finally:
    # Release the connection even when a call fails.
    transport.close()
def handler(event, context):
    """Train MobileNet, synchronising the model via a parameter server.

    The training set is downloaded from ``event['data_bucket']`` /
    ``event['key']``. The test file name, learning rate and number of
    epochs come from the module-level names ``test_file``,
    ``learning_rate`` and ``num_epochs``.

    :param event: dict with the keys ``data_bucket``, ``rank``,
        ``num_workers`` and ``key``
    :param context: not used
    """
    bucket = event['data_bucket']
    worker_index = event['rank']
    num_worker = event['num_workers']
    key = event['key']

    print('bucket = {}'.format(bucket))
    print('number of workers = {}'.format(num_worker))
    print('worker index = {}'.format(worker_index))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(constants.HOST, constants.PORT)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()

    try:
        # test thrift connection
        ps_client.ping(t_client)
        print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
            constants.HOST, constants.PORT))

        print('data_bucket = {}\n worker_index:{}\n num_worker:{}\n key:{}'.format(
            bucket, worker_index, num_worker, key))

        # read file from s3
        readS3_start = time.time()
        train_path = download_file(bucket, key)
        trainset = torch.load(train_path)
        test_path = download_file(bucket, test_file)
        testset = torch.load(test_path)
        print("read data cost {} s".format(time.time() - readS3_start))

        preprocess_start = time.time()
        batch_size = 200
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=batch_size,
                                                   shuffle=True)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=batch_size,
                                                  shuffle=False)
        device = 'cpu'
        print("preprocess data cost {} s".format(time.time() - preprocess_start))

        model = MobileNet()
        model = model.to(device)

        # Loss and Optimizer
        # Softmax is internally computed.
        # Set parameters to be updated.
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                                    momentum=0.9, weight_decay=5e-4)

        # register model
        # Record shape and element count of every parameter tensor so the
        # flat model pulled from the server can be cut back into layers.
        model_name = "mobilenet"
        parameter_shape = []
        parameter_length = []
        model_length = 0
        for param in model.parameters():
            tmp_shape = 1
            parameter_shape.append(param.data.numpy().shape)
            for w in param.data.numpy().shape:
                tmp_shape *= w
            parameter_length.append(tmp_shape)
            model_length += tmp_shape

        ps_client.register_model(t_client, worker_index, model_name,
                                 model_length, num_worker)
        ps_client.exist_model(t_client, model_name)
        print("register and check model >>> name = {}, length = {}".format(
            model_name, model_length))

        # Training the Model
        train_start = time.time()
        iter_counter = 0
        for epoch in range(num_epochs):
            epoch_start = time.time()
            model.train()
            for batch_index, (inputs, targets) in enumerate(train_loader):
                print("------worker {} epoch {} batch {}------".format(
                    worker_index, epoch, batch_index))
                batch_start = time.time()

                # pull latest model
                ps_client.can_pull(t_client, model_name, iter_counter,
                                   worker_index)
                latest_model = ps_client.pull_model(t_client, model_name,
                                                    iter_counter, worker_index)
                pos = 0
                for layer_index, param in enumerate(model.parameters()):
                    param.data = Variable(
                        torch.from_numpy(
                            np.asarray(latest_model[pos:pos + parameter_length[layer_index]],
                                       dtype=np.float32).reshape(
                                           parameter_shape[layer_index])))
                    pos += parameter_length[layer_index]

                # Forward + Backward + Optimize
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                optimizer.zero_grad()
                loss.backward()

                # flatten and concat the per-layer arrays into one vector
                # NOTE(review): this concatenates param.data (the weights),
                # not param.grad, although the result is pushed with
                # push_grad - confirm that this is intended.
                param_grad = np.zeros((1))
                for param in model.parameters():
                    param_grad = np.concatenate(
                        (param_grad, param.data.numpy().flatten()))
                param_grad = np.delete(param_grad, 0)
                print("model_length = {}".format(param_grad.shape))

                # push gradient to PS
                sync_start = time.time()
                # Compute time of this batch; this name was read by the
                # progress line below without ever being assigned.
                cal_time = sync_start - batch_start
                print(
                    ps_client.can_push(t_client, model_name, iter_counter,
                                       worker_index))
                print(
                    ps_client.push_grad(t_client, model_name, param_grad,
                                        learning_rate, iter_counter,
                                        worker_index))
                print(
                    ps_client.can_pull(t_client, model_name, iter_counter + 1,
                                       worker_index))  # sync all workers
                sync_time = time.time() - sync_start

                # len(train_loader) is the number of batches per epoch; the
                # former expression used a name not defined in this function.
                print(
                    'Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                    'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                    % (epoch + 1, num_epochs, batch_index + 1,
                       len(train_loader), time.time() - train_start,
                       loss.data, time.time() - epoch_start,
                       time.time() - batch_start, cal_time, sync_time))
                iter_counter += 1

            test(epoch, model, test_loader, criterion, device)
            # NOTE(review): nesting level of this call was reconstructed
            # from a flattened source - confirm it belongs to the epoch loop.
            optimizer.step()
    finally:
        # Always release the connection to the parameter server.
        transport.close()
def vector_addition_dfe(size, scalar, in_a, in_b):
    """VectorAddition DFE implementation.

    Connects to the Thrift server on localhost:9090, uploads both input
    vectors, writes ``in_a`` to LMem, runs the ``default`` action with
    ``in_b`` and ``scalar`` and downloads the output stream. The
    duration of every step is printed.

    :param size: number of int32 elements per vector
    :param scalar: value passed as action parameter ``A``
    :param in_a: first input vector (written to LMem)
    :param in_b: second input vector (queued as input ``y``)
    :return: the output stream received from the server
    :raises SystemExit: with status -1 when a Thrift error occurs
    """
    try:
        start_time = time.time()

        # Make socket
        sock = TSocket.TSocket('localhost', 9090)

        # Buffering is critical. Raw sockets are very slow
        transport = TTransport.TBufferedTransport(sock)

        # Wrap in a protocol
        protocol = TBinaryProtocol.TBinaryProtocol(transport)

        # Create a client to use the protocol encoder
        client = VectorAdditionService.Client(protocol)
        print('Creating a client:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Connect!
        start_time = time.time()
        transport.open()
        print('Opening connection:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Initialize maxfile
        start_time = time.time()
        max_file = client.VectorAddition_init()
        print('Initializing maxfile:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Load DFE
        start_time = time.time()
        max_engine = client.max_load(max_file, '*')
        print('Loading DFE:\t\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Allocate and send input streams to server
        start_time = time.time()
        address_in_a = client.malloc_int32_t(size)
        client.send_data_int32_t(address_in_a, in_a)

        address_in_b = client.malloc_int32_t(size)
        client.send_data_int32_t(address_in_b, in_b)
        print('Sending input data:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Allocate memory for output stream on server
        start_time = time.time()
        address_data_out = client.malloc_int32_t(size)
        print('Allocating memory for output stream on server:\t%.5lfs' %
              (time.time() - start_time))

        # Write vector a to LMem (size * 4 = byte length of int32 data)
        start_time = time.time()
        action = client.max_actions_init(max_file, "writeLMem")
        client.max_set_param_uint64t(action, "address", 0)
        client.max_set_param_uint64t(action, "nbytes", size * 4)
        client.max_queue_input(action, "cpu_to_lmem", address_in_a, size * 4)
        client.max_run(max_engine, action)
        print('Writing to LMem:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Add two vectors and a scalar
        start_time = time.time()
        action = client.max_actions_init(max_file, "default")
        client.max_set_param_uint64t(action, "N", size)
        client.max_set_param_uint64t(action, "A", scalar)
        client.max_queue_input(action, "y", address_in_b, size * 4)
        client.max_queue_output(action, "s", address_data_out, size * 4)
        client.max_run(max_engine, action)
        print('Vector addition time:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Unload DFE
        start_time = time.time()
        client.max_unload(max_engine)
        print('Unloading DFE:\t\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Get output stream from server
        start_time = time.time()
        data_out = client.receive_data_int32_t(address_data_out, size)
        print('Getting output stream:\t(size = %d bit)\t%.5lfs' %
              ((size * 32), (time.time() - start_time)))

        # Free allocated memory for streams on server
        start_time = time.time()
        client.free(address_in_a)
        client.free(address_in_b)
        client.free(address_data_out)
        client.free(action)
        print('Freeing allocated memory for streams on server:\t%.5lfs' %
              (time.time() - start_time))

        # Free allocated maxfile data
        start_time = time.time()
        client.VectorAddition_free()
        print('Freeing allocated maxfile data:\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Close!
        start_time = time.time()
        transport.close()
        print('Closing connection:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

    # "except ... as" and print(...) with one argument are valid on both
    # Python 2.6+ and Python 3.
    except Thrift.TException as thrift_exception:
        print('%s' % (thrift_exception.message))
        sys.exit(-1)

    # The downloaded stream used to be discarded; hand it to the caller.
    return data_out
def __init__(self, hostname, port):
    """Create an Hbase Thrift client for ``hostname:port`` and connect.

    The buffered transport, binary protocol and client are stored on the
    instance; the transport is opened before returning.
    """
    sock = TSocket.TSocket(hostname, port)
    self.transport = TTransport.TBufferedTransport(sock)
    self.protocol = TBinaryProtocol.TBinaryProtocol(self.transport)
    self.client = Hbase.Client(self.protocol)
    self.transport.open()
def execute_hurdle(make_plot=False,
                   result_file="results.json",
                   host='127.0.0.1',
                   port=9090,
                   seed=None,
                   initial_state=None,
                   num_trials=10,
                   num_rounds=30000,
                   scoring_rounds=1000,
                   test_label="team"):
    """Run the Hurdle 3 trials against a Thrift ``Hurdle3Execution`` service.

    Each trial creates a fresh ``PSM`` with its own seed (and, unless
    ``initial_state`` is given, a randomly chosen initial state), runs it
    through ``run_trial`` and records the outcome.  The aggregated results
    are written to ``result_file`` as JSON and the remote side is told to
    stop.

    Args:
        make_plot: Accepted for compatibility; not used by this function.
        result_file: Path of the JSON file the results are written to.
        host: Host of the Thrift server.
        port: Port of the Thrift server.
        seed: Top-level seed; the same seed yields the same trial seeds.
        initial_state: Fixed initial state for every trial, or ``None`` to
            draw one per trial.
        num_trials: Number of trials to run.
        num_rounds: Rounds per trial (forwarded to ``run_trial``).
        scoring_rounds: Forwarded to ``run_trial``.
        test_label: Label stored in the results.
    """
    num_states = 10
    avg_score_threshold = 2.0
    trial_pass_threshold = 6

    expected = expected_random_score(num_states)

    print("expected score of random guesser is {}".format(expected *
                                                          num_rounds))
    print("score required to pass a trial is {}".format(avg_score_threshold *
                                                        num_rounds))
    print("Number of trials passed to pass Hurdle 3 is {} out of {}".format(
        trial_pass_threshold, num_trials))

    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(TSocket.TSocket(host, port))

    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)

    # Create a client to use the protocol encoder
    client = Hurdle3Execution.Client(protocol)

    # Connect!
    transport.open()

    # The transport used to stay open forever; close it even when a trial
    # raises.
    try:
        results = {"trials": {}, "test_label": test_label}

        # A dedicated random number generator guarantees repeatability of
        # the evaluation without forcing each trial to use the same seed:
        # the same top level seed always yields the same set of trial seeds.
        rng = np.random.RandomState(seed)

        for t in range(num_trials):
            # generate a unique seed per trial
            trial_seed = rng.randint(0, 0xffffffff)

            if initial_state is None:
                trial_initial_state = rng.choice(range(num_states))
            else:
                trial_initial_state = initial_state

            # create a new probabilistic state machine with potentially a
            # new initial state and seed at the start of each trial
            psm = PSM(num_states, trial_initial_state, trial_seed)

            print("starting trial {} of {} with trial seed {}".format(
                t, num_trials, trial_seed))

            # run the trial and store the trial results
            trial_results = run_trial(t, num_rounds, scoring_rounds,
                                      avg_score_threshold, client, psm)

            # add seed and initial state to trial_results
            trial_results["seed"] = trial_seed
            trial_results["initial_state"] = trial_initial_state

            results["trials"][t] = trial_results

        # count the number of trials that passed
        trials_passed = sum(results["trials"][i]["trial_pass"]
                            for i in range(num_trials))

        print(
            "Number of trials passed: {} Number of trials passed required to pass Hurdle 3: {}"
            .format(trials_passed, trial_pass_threshold))

        hurdle_pass = trials_passed >= trial_pass_threshold
        print("Hurdle 3 Passed? {}".format(hurdle_pass))

        results["main_seed"] = seed
        results["num_trials"] = num_trials
        results["num_states"] = num_states
        results["trials_passed"] = trials_passed
        results["trial_pass_threshold"] = trial_pass_threshold
        results["hurdle_pass"] = hurdle_pass

        with open(result_file, 'w') as f:
            json.dump(results, f)

        print("Writing results to file: {}".format(result_file))

        client.stop()
    finally:
        transport.close()
def __init__(self):
    """Connect a ``ClassificationService`` Thrift client to the endpoint
    given by ``config.IP`` / ``config.PORT``."""
    raw_socket = TSocket.TSocket(config.IP, config.PORT)
    self.transport = TTransport.TBufferedTransport(raw_socket)
    binary_protocol = TBinaryProtocol.TBinaryProtocol(self.transport)
    self.client = ClassificationService.Client(binary_protocol)
    # Only the buffered transport and the client are kept on the instance.
    self.transport.open()
def __init__(self, host, port):
    """Open a buffered binary-protocol Thrift connection to ``host``:``port``
    and create the ``GraphService`` client on top of it."""
    raw_socket = TSocket.TSocket(host, port)
    buffered = TTransport.TBufferedTransport(raw_socket)
    self.transport = buffered
    self.protocol = TBinaryProtocol.TBinaryProtocol(buffered)
    # The transport is opened before the client object is created.
    self.transport.open()
    self.client = GraphService.Client(self.protocol)
def main(argv):
    """Dispatch one CLI command to the local database and the AEV Thrift server.

    ``argv`` is parsed with ``getopt``: ``-m <mode>`` selects the CLI mode,
    ``-c <command>`` the command inside that mode, and ``--i1`` .. ``--i10``
    carry the positional parameters of the command.  ``-n`` is accepted but
    has no effect.

    The function opens a pymonetdb connection (published through the global
    ``cursor``) and a Thrift connection to localhost:34532, runs the selected
    command and closes both connections again, also when the command fails
    or the arguments cannot be parsed.

    Args:
        argv: Command-line arguments without the program name.

    Exits with status 2 when ``argv`` cannot be parsed.
    """
    global cursor
    global ifprop
    global vxtunprop
    global ret
    global vxret
    global collector_list

    connection = pymonetdb.connect(username="******",
                                   password="******",
                                   hostname="localhost",
                                   database="voc")
    transport = None
    try:
        connection.set_autocommit(True)
        cursor = connection.cursor()
        cursor.arraysize = 100

        # Buffering is critical. Raw sockets are very slow
        transport = TTransport.TBufferedTransport(
            TSocket.TSocket('localhost', 34532))

        # Wrap in a protocol
        protocol = TBinaryProtocol.TBinaryProtocol(transport)

        # Create a client to use the protocol encoder
        client = aev_config.Client(protocol)

        # Connect!
        transport.open()
        client.ping()

        # getopt returns (options, remainder): short options start with '-',
        # long options with '--', and everything that is not an option ends
        # up in the remainder list.
        try:
            opts, args = getopt.getopt(argv, "m:c:n", [
                "i1=", "i2=", "i3=", "i4=", "i5=", "i6=", "i7=", "i8=", "i9=",
                "i10="
            ])
        except getopt.GetoptError:
            print('test.py -i <inputfile> -o <outputfile>')
            sys.exit(2)

        # All the params passed by the clish live in the local scope of main.
        # values[k] holds the argument of "--i<k>"; index 0 is unused.
        positional = dict(("--i%d" % index, index) for index in range(1, 11))
        values = [None] * 11
        mode = command = None
        for opt, arg in opts:
            # Exact comparison: `opt in ("--i1")` was a substring test on a
            # plain string, not a membership test on a tuple.
            if opt in positional:
                values[positional[opt]] = arg
                if opt == "--i9":
                    print("\n" + arg)
            elif opt == "-m":
                mode = arg
            elif opt == "-c":
                command = arg
            # "-n" is accepted by the parser; its value was never used.
        param1 = values[1]
        param2 = values[2]

        if mode == 'enable':
            if command == 'show_iface':
                selectstring = "SELECT * FROM if_table"
                cursor.execute(selectstring)
                print(cursor.fetchmany())

        elif mode == 'config':
            if command == "collsession":
                # Checked as int at CLI only
                collector_session_check(param1)

            elif command == "delTunnel":
                vxtunprop = aev_vxlan_tunnel_prop()
                delTunnelName = param1
                # Delete both table rows in case the tunnel name exists
                deleteRet = deleteBothTnlTbl(delTunnelName)
                # `==` on purpose: the helpers may report 1/0 instead of
                # True/False.
                if deleteRet == True:
                    # send delete to the BCM server
                    # NOTE(review): tunnel_name is not set on vxtunprop in
                    # this branch - confirm deleteBothTnlTbl fills it in.
                    client.aev_if_prop_delete(1, vxtunprop)
                elif deleteRet == False:
                    print(
                        '\n Such a tunnel name do not exists, please check the input'
                    )

            elif command == "delcollsession":
                vxtunprop = aev_vxlan_tunnel_prop()
                delSessId = param1
                if str(delSessId).strip():
                    print('\n The session id is not empty')
                    # Acts to be performed:
                    # 1. Delete the session id row
                    # 2. Delete the session id from the tunnel table
                    # 3. Send delete to the board for aev
                    print(
                        '\n The result of the delete from collector table is \n'
                    )
                    print(delSessTabRow(delSessId))
                    print(
                        '\n The result of update the session id from the tunnel table is \n'
                    )
                    tunnelName = deleteSessIdTnlTbl(delSessId)
                    if tunnelName != 0:
                        print(
                            '\n The result of delete sent to the BCM board \n')
                        vxtunprop.tunnel_name = tunnelName
                        print('\n Send delete to the BCM \n')
                        example = client.aev_if_prop_delete(1, vxtunprop)
                        print(example)
                    else:
                        print('\n the tunnelname was not found')

            elif command == "iface":
                ret = set_config_mode(param1)

            elif command == "tunnel":
                tunName = param2
                tunType = param1
                if not str(tunName).strip():
                    print('\n The tunnel name is not as per the sepcification')
                flag = crupdateTnlTab(tunName, tunType)
                if flag == 1:
                    print(
                        '\n New Tunnel row has been created and type inserted '
                    )
                elif flag == 2:
                    print('\n Tunnel row has been modified with new type')
                elif flag == 3:
                    print(
                        '\n Tunnel row has not been modified and type exists')

        elif mode == "config-iface":
            ifprop = aev_if_prop()
            if command == 'speed':
                ret = set_speed(param1, param2)
            elif command == 'shutdown':
                ret = set_shutdown(param1)
            elif command == 'noshutdown':
                ret = set_noshutdown(param1)
            elif command == 'mtu':
                ret = set_mtu(param1, param2)
            elif command == 'autonego':
                ret = set_autonego(param1, param2)
            if ret == 0:
                example = client.aev_if_prop_update(1, ifprop)
                print(example)

        elif mode == "config_collector_sess":
            if command == "dest_tunnel":
                collSessId = param2  # Int
                tunName = param1  # String
                srcVlan = 0  # Int
                destTunnelRet = dstTunnelName(tunName, srcVlan, collSessId)
                if destTunnelRet == 0:
                    print(
                        "Tunnel port_access is updated in VXLAN tunnel table")
                elif destTunnelRet == 1:
                    print("Tunnel doesn't exist but the VLAN is updated")
                elif destTunnelRet == 2:
                    print(
                        "\n The tunnel cannot be set as the source interface has not been found in the session table row"
                    )
                elif destTunnelRet == -1:
                    print(
                        "\n Tunnel exists but currently in use by other session"
                    )

            elif command == "sourceint":
                vxtunprop = aev_vxlan_tunnel_prop()
                collSessId = param2  # integer
                print(type(collSessId))
                sourceInt = param1  # subcommand string
                # Check valid source interface or not
                sourceIntRet = checkValidInterface(sourceInt)
                if sourceIntRet == True:
                    # The source interface is valid and free: try to update
                    # the collector table.
                    intSessRel = modifyNewCollectorTable(sourceInt, collSessId)
                    if intSessRel == 0 or intSessRel == 1:
                        print(
                            "\n Already the interface exists with this session or the interface is assigned"
                        )
                        # Check whether the destination tunnel is set; if it
                        # is, the ifindex has to be checked or updated.
                        dstTunnelTestRet = dstTunnelTest(sourceInt, collSessId)
                        if dstTunnelTestRet == True:
                            # True is sent only if vxtunprop was set by the
                            # method.
                            example = client.aev_vxlan_tunnel_create(
                                1, vxtunprop)
                            print(example)
                        else:
                            print(
                                '\nEither the interface is pushed inside the vxlan_tnl_tab or the same interface was sent by the user'
                            )
                    elif intSessRel == -1:
                        print("\n Can't help, somebody else is using it")
                elif sourceIntRet == False:
                    print("Not a valid interface selected")

        elif mode == "config-tunnel":
            vxtunprop = aev_vxlan_tunnel_prop()
            if command == "tunnelprop":
                param = list(values[1:])
                for vxretobj in updateVxlanTab(param):
                    print(vxretobj)
                    #  2: something is wrong with the provided parameters
                    #  1: all of them are set, send create to the BCM
                    # -1: send delete to the BCM only with a tunnel name
                    #  0: the source interface is not set yet, entries were
                    #     only updated
                    if vxretobj == 0:
                        print(
                            '\n The values except the source interface has been updated into vx_tnl_tbl'
                        )
                    if vxretobj == -1:
                        print('\n Send delete to the BCM')
                        ret = client.aev_if_prop_delete(1, vxtunprop)
                        if ret == True:
                            # Let the driver quote the value instead of
                            # interpolating it into the SQL text.
                            cursor.execute(
                                "DELETE FROM vxlan_tnl_table WHERE vxlan_tnl_table.tnl_name=%(tnl_name)s",
                                {"tnl_name": vxtunprop.tunnel_name})
                    if vxretobj == 1:
                        example = client.aev_if_prop_create(1, vxtunprop)
                        print(example)
                    if vxretobj == 2:
                        print("\n Something is wrong with provided parameters")
    finally:
        # Release both connections on every exit path, including sys.exit.
        try:
            connection.close()
        finally:
            if transport is not None:
                transport.close()
def simple_dfe(size, data_in):
    """Run the Simple design on a DFE through the Thrift service.

    Connects to the Thrift server on localhost:9090, uploads ``data_in``,
    runs the default action and downloads the output stream.  The
    wall-clock time of every step is printed to stdout.

    Args:
        size: Number of elements in the input and output streams.
        data_in: Input stream sent to the server.

    Returns:
        The output stream received from the server.

    On a Thrift error the message is printed and the process exits with
    status -1.
    """
    try:
        start_time = time.time()

        # Make socket
        sock = TSocket.TSocket('localhost', 9090)

        # Buffering is critical. Raw sockets are very slow
        transport = TTransport.TBufferedTransport(sock)

        # Wrap in a protocol
        protocol = TBinaryProtocol.TBinaryProtocol(transport)

        # Create a client to use the protocol encoder
        client = SimpleService.Client(protocol)
        print('Creating a client:\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Connect!
        start_time = time.time()
        transport.open()
        print('Opening connection:\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Initialize maxfile
        start_time = time.time()
        max_file = client.Simple_init()
        print('Initializing maxfile:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Load DFE
        start_time = time.time()
        max_engine = client.max_load(max_file, '*')
        print('Loading DFE:\t\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Allocate and send input streams to server
        start_time = time.time()
        address_data_in = client.malloc_float(size)
        client.send_data_float(address_data_in, data_in)
        print('Sending input data:\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Allocate memory for output stream on server
        start_time = time.time()
        address_data_out = client.malloc_float(size)
        print('Allocating memory for output stream on server:\t%.5lfs' %
              (time.time() - start_time))

        # Action default
        start_time = time.time()
        actions = Simple_actions_t_struct(size, address_data_in,
                                          address_data_out)
        address_actions = client.send_Simple_actions_t(actions)
        client.Simple_run(max_engine, address_actions)
        print('Simple time:\t\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Unload DFE
        start_time = time.time()
        client.max_unload(max_engine)
        print('Unloading DFE:\t\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Get output stream from server
        start_time = time.time()
        data_out = client.receive_data_float(address_data_out, size)
        print('Getting output stream:\t(size = %d bit)\t%.5lfs' %
              ((size * 32), (time.time() - start_time)))

        # Free allocated memory for streams on server
        start_time = time.time()
        client.free(address_data_in)
        client.free(address_data_out)
        client.free(address_actions)
        print('Freeing allocated memory for streams on server:\t%.5lfs' %
              (time.time() - start_time))

        # Free allocated maxfile data
        start_time = time.time()
        client.Simple_free()
        print('Freeing allocated maxfile data:\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Close!
        start_time = time.time()
        transport.close()
        print('Closing connection:\t\t\t\t%.5lfs' % (time.time() - start_time))

    except Thrift.TException as thrift_exception:
        # "as" form and print() are valid on Python 2.6+ and Python 3.
        print('%s' % (thrift_exception.message))
        sys.exit(-1)

    # The result was previously fetched and then thrown away.
    return data_out
from thrift.protocol import TBinaryProtocol
from thrift.transport import TSocket, TTransport

from thriftDemo.service.hello import HelloService

if __name__ == '__main__':
    # Demo client: send one ping to the HelloService on localhost:12580 and
    # print the reply.
    server = TSocket.TSocket("localhost", 12580)
    transport = TTransport.TBufferedTransport(server)
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    client = HelloService.Client(protocol)

    transport.open()
    try:
        pong = client.ping("hello, world!")
        print("pong %s" % pong)
    finally:
        # Close the connection even when the call fails.
        transport.close()
def handler(event, context):
    """Lambda entry point: train one round of epochs against a parameter server.

    The function downloads the train/test sets from S3, optionally restores
    the checkpoint of the previous round, registers the model with the
    Thrift parameter server and then, for every batch, pulls the latest
    model, computes the gradients locally and pushes them back.  When the
    training is not finished after this round, worker 0 uploads a checkpoint
    and every worker asynchronously invokes ``function_name`` again with
    ``start_epoch`` advanced.

    Args:
        event: Dict with the job description.  Keys read here:
            train_file, test_file, data_bucket, n_workers, worker_index,
            cp_bucket, host, port, model, optim, sync_mode, lr, batch_size,
            n_epochs, start_epoch, run_epochs, function_name; n_features
            and n_classes are additionally forwarded to the next round.
        context: Lambda context object; not used.
    """
    start_time = time.time()

    # dataset setting
    train_file = event['train_file']
    test_file = event['test_file']
    data_bucket = event['data_bucket']
    n_workers = event['n_workers']
    worker_index = event['worker_index']
    cp_bucket = event['cp_bucket']

    # ps setting
    host = event['host']
    port = event['port']

    # training setting
    model_name = event['model']
    optim = event['optim']
    sync_mode = event['sync_mode']
    assert model_name.lower() in MLModel.Deep_Models
    assert optim.lower() in Optimization.Grad_Avg
    assert sync_mode.lower() in Synchronization.Reduce

    # hyper-parameter
    learning_rate = event['lr']
    batch_size = event['batch_size']
    n_epochs = event['n_epochs']
    start_epoch = event['start_epoch']
    run_epochs = event['run_epochs']

    function_name = event['function_name']

    print('data bucket = {}'.format(data_bucket))
    print("train file = {}".format(train_file))
    print("test file = {}".format(test_file))
    print('number of workers = {}'.format(n_workers))
    print('worker index = {}'.format(worker_index))
    print('model = {}'.format(model_name))
    print('optimization = {}'.format(optim))
    print('sync mode = {}'.format(sync_mode))
    print('start epoch = {}'.format(start_epoch))
    print('run epochs = {}'.format(run_epochs))
    print('host = {}'.format(host))
    print('port = {}'.format(port))

    print("Run function {}, round: {}/{}, epoch: {}/{} to {}/{}".format(
        function_name,
        int(start_epoch / run_epochs) + 1, math.ceil(n_epochs / run_epochs),
        start_epoch + 1, n_epochs, start_epoch + run_epochs, n_epochs))

    # download file from s3
    storage = S3Storage()
    local_dir = "/tmp"
    read_start = time.time()
    storage.download(data_bucket, train_file,
                     os.path.join(local_dir, train_file))
    storage.download(data_bucket, test_file,
                     os.path.join(local_dir, test_file))
    print("download file from s3 cost {} s".format(time.time() - read_start))

    train_set = torch.load(os.path.join(local_dir, train_file))
    test_set = torch.load(os.path.join(local_dir, test_file))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True)
    n_train_batch = len(train_loader)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=100,
                                              shuffle=False)

    print("read data cost {} s".format(time.time() - read_start))

    random_seed = 100
    torch.manual_seed(random_seed)

    device = 'cpu'
    model = deep_models.get_models(model_name).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # load checkpoint model if it is not the first round
    if start_epoch != 0:
        checked_file = 'checkpoint_{}.pt'.format(start_epoch - 1)
        storage.download(cp_bucket, checked_file,
                         os.path.join(local_dir, checked_file))
        checkpoint_model = torch.load(os.path.join(local_dir, checked_file))

        model.load_state_dict(checkpoint_model['model_state_dict'])
        optimizer.load_state_dict(checkpoint_model['optimizer_state_dict'])
        print("load checkpoint model at epoch {}".format(start_epoch - 1))

    # Set thrift connection
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(TSocket.TSocket(host, port))
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()

    # The connection used to be left open; close it on every exit path.
    try:
        # test thrift connection
        ps_client.ping(t_client)
        print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
            host, port))

        # register model: remember shape and flat length of every parameter
        parameter_shape = []
        parameter_length = []
        model_length = 0
        for param in model.parameters():
            shape = param.data.numpy().shape
            length = int(np.prod(shape))
            parameter_shape.append(shape)
            parameter_length.append(length)
            model_length += length

        ps_client.register_model(t_client, worker_index, model_name,
                                 model_length, n_workers)
        ps_client.exist_model(t_client, model_name)
        print("register and check model >>> name = {}, length = {}".format(
            model_name, model_length))

        # Training the Model
        train_start = time.time()
        iter_counter = 0
        end_epoch = min(start_epoch + run_epochs, n_epochs)
        for epoch in range(start_epoch, end_epoch):

            model.train()
            epoch_start = time.time()

            train_acc = Accuracy()
            train_loss = Average()

            for batch_idx, (inputs, targets) in enumerate(train_loader):
                batch_start = time.time()
                batch_cal_time = 0
                batch_comm_time = 0

                # pull latest model
                ps_client.can_pull(t_client, model_name, iter_counter,
                                   worker_index)
                latest_model = ps_client.pull_model(t_client, model_name,
                                                    iter_counter, worker_index)
                pos = 0
                for layer_index, param in enumerate(model.parameters()):
                    layer_length = parameter_length[layer_index]
                    param.data = Variable(
                        torch.from_numpy(
                            np.asarray(latest_model[pos:pos + layer_length],
                                       dtype=np.float32).reshape(
                                           parameter_shape[layer_index])))
                    pos += layer_length
                batch_comm_time += time.time() - batch_start

                batch_cal_start = time.time()
                outputs = model(inputs)
                loss = F.cross_entropy(outputs, targets)
                optimizer.zero_grad()
                loss.backward()

                # Flatten and concat the gradients of weight and bias.  The
                # previous code concatenated `param.data` (the weights
                # themselves) although the vector is pushed as a gradient.
                # The leading empty float64 array keeps the result float64
                # and makes the call valid for a model without parameters.
                flat_grads = [np.zeros(0)]
                for layer_index, param in enumerate(model.parameters()):
                    if param.grad is None:
                        # No gradient for this parameter: contribute zeros so
                        # the vector keeps the registered length.
                        flat_grads.append(
                            np.zeros(parameter_length[layer_index]))
                    else:
                        flat_grads.append(param.grad.data.numpy().flatten())
                param_grad = np.concatenate(flat_grads)
                batch_cal_time += time.time() - batch_cal_start

                # push gradient to PS
                batch_push_start = time.time()
                ps_client.can_push(t_client, model_name, iter_counter,
                                   worker_index)
                ps_client.push_grad(t_client, model_name, param_grad,
                                    -1. * learning_rate / n_workers,
                                    iter_counter, worker_index)
                # sync all workers
                ps_client.can_pull(t_client, model_name, iter_counter + 1,
                                   worker_index)
                batch_comm_time += time.time() - batch_push_start

                train_acc.update(outputs, targets)
                train_loss.update(loss.item(), inputs.size(0))

                optimizer.step()
                iter_counter += 1

                if batch_idx % 10 == 0:
                    print(
                        'Epoch: [%d/%d], Batch: [%d/%d], Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                        'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                        % (epoch + 1, n_epochs, batch_idx + 1, n_train_batch,
                           time.time() - train_start, loss.item(),
                           time.time() - epoch_start,
                           time.time() - batch_start, batch_cal_time,
                           batch_comm_time))

            test_loss, test_acc = test(epoch, model, test_loader)

            print(
                'Epoch: {}/{},'.format(epoch + 1, n_epochs),
                'train loss: {},'.format(train_loss),
                'train acc: {},'.format(train_acc),
                'test loss: {},'.format(test_loss),
                'test acc: {}.'.format(test_acc),
            )

        # Training is not finished yet: invoke the next round.  `epoch` only
        # exists when at least one epoch ran in this round.
        if end_epoch > start_epoch and epoch < n_epochs - 1:
            checkpoint_model = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_loss.average
            }

            checked_file = 'checkpoint_{}.pt'.format(epoch)

            if worker_index == 0:
                torch.save(checkpoint_model,
                           os.path.join(local_dir, checked_file))
                storage.upload(cp_bucket, checked_file,
                               os.path.join(local_dir, checked_file))
                print("checkpoint model at epoch {} saved!".format(epoch))

            print(
                "Invoking the next round of functions. round: {}/{}, start epoch: {}, run epoch: {}"
                .format(
                    int((epoch + 1) / run_epochs) + 1,
                    math.ceil(n_epochs / run_epochs), epoch + 1, run_epochs))
            lambda_client = boto3.client('lambda')
            # Forward the job description unchanged except for start_epoch.
            forwarded_keys = ('train_file', 'test_file', 'data_bucket',
                              'n_features', 'n_classes', 'n_workers',
                              'worker_index', 'cp_bucket', 'host', 'port',
                              'model', 'optim', 'sync_mode', 'lr',
                              'batch_size', 'n_epochs', 'start_epoch',
                              'run_epochs', 'function_name')
            payload = dict((key, event[key]) for key in forwarded_keys)
            payload['start_epoch'] = epoch + 1
            lambda_client.invoke(FunctionName=function_name,
                                 InvocationType='Event',
                                 Payload=json.dumps(payload))
    finally:
        transport.close()

    end_time = time.time()
    print("Elapsed time = {} s".format(end_time - start_time))
def findHeavyHitters(table, today=None, verbose=False):
    """
    Find heavy hitters in the given traffic (table) and store the results in
    the 'suspiciousheavyhitters' Hive table.

    An endpoint is a candidate when today's packet or byte count exceeds the
    mean plus three standard deviations over all endpoints.  A candidate is
    reported when it also exceeds mean plus three standard deviations of its
    own counts over the previous 15 days, or (with confidence "MED") when
    that comparison fails with a TypeError.

    Args:
        table: Name of the traffic table; must start with "netflow" or
            "sflow" after scrubbing, otherwise the process exits with -1.
        today: Day to analyse (``datetime.date``).  ``None`` means the
            current day, evaluated at call time.
        verbose: Write progress messages to stdout.
    """
    # A default of `datetime.date.today()` in the signature is evaluated
    # once at import time and goes stale in a long-running process.
    if today is None:
        today = datetime.date.today()

    histNbDay = 15
    date = "%d%02d%02d" % (today.year, today.month, today.day)
    dates = [
        "%d%02d%02d" % (x.year, x.month, x.day)
        for x in pd.date_range(today - datetime.timedelta(histNbDay),
                               today - datetime.timedelta(1))
    ]

    table = scrub(table)

    ## set some variables regarding the input data
    if table.startswith("netflow"):
        endpointTypes = [("dstip", "da"), ("srcip", "sa")]
        req0 = "select {endpoint}, sum(ipkt) nbpkt, sum(ibyt) nbbyte from {table} where dt=%s group by {endpoint}"
        req1 = "select {genericLabel}, avg(nbpkt) as avgpkt, stddev_samp(nbpkt) as stdpkt, avg(nbbyt) as avgbyt, stddev_samp(nbbyt) as stdbyt from(select {endpointType} as {genericLabel}, dt, sum(ipkt) as nbpkt, sum(ibyt) as nbbyt from {table} where {endpointType} IN ({suspiciousIP}) and dt IN ({dates}) group by {endpointType}, dt order by {endpointType}, dt) group by {genericLabel}"
    elif table.startswith("sflow"):
        endpointTypes = [("dstip", "dstip"), ("srcip", "srcip"),
                         ("dstip", "dstip6"), ("srcip", "srcip6")]
        req0 = "select {endpoint}, count(*) nbpkt, sum(ipsize) nbbyte from {table} where dt=%s and {endpoint}<>'' group by {endpoint}"
        req1 = "select {genericLabel}, avg(nbpkt) as avgpkt, stddev_samp(nbpkt) as stdpkt, avg(nbbyt) as avgbyt, stddev_samp(nbbyt) as stdbyt from(select {endpointType} as {genericLabel}, dt, count(*) as nbpkt, sum(ipsize) as nbbyt from {table} where {endpointType} IN ({suspiciousIP}) and dt IN ({dates}) group by {endpointType}, dt order by {endpointType}, dt) group by {genericLabel}"
    else:
        sys.stderr.write("Data type unknown!")
        sys.exit(-1)

    def sqlList(items):
        # Render values as a quoted, comma separated SQL list.  Replaces the
        # Python-2-only `str.translate(str(list), None, "u[]")` trick and
        # doubles embedded quotes.
        return ", ".join("'%s'" % ("%s" % item).replace("'", "''")
                         for item in items)

    datesList = sqlList(dates)
    outputPath = "%s/suspiciousheavyhitters_%s_%s.txt" % (outputDirectory,
                                                          table, date)
    cursor = presto.connect('localhost').cursor()

    # `with` closes the file even when a query fails half way through.
    with open(outputPath, "w") as outputFile:
        for genericLabel, endpointType in endpointTypes:
            if verbose:
                sys.stdout.write("Looking for %s heavy hitters... (%s,%s)\n" %
                                 (date, table, genericLabel))
            suspiciousIP = set()

            # get today's data
            formatedReq = req0.format(endpoint=endpointType, table=table)
            cursor.execute(formatedReq, [date])
            res = cursor.fetchall()

            if len(res) == 0:
                continue

            data = pd.DataFrame(res, columns=[genericLabel, "nbpkt", "nbbyt"])
            data.index = data.pop(genericLabel)

            # find today's heavy hitter
            for aggType in ["nbpkt", "nbbyt"]:
                threshold = data[aggType].mean() + 3 * data[aggType].std()
                # `.loc` replaces `.ix`, which is gone from current pandas.
                suspiciousIP.update(
                    data.loc[data[aggType] > threshold].index.tolist())

            # check in past data if they had similar behavior
            if verbose:
                sys.stdout.write("Retrieve past data...\n")
            suspiciousIP = list(suspiciousIP)
            # Query the history in chunks of 100 addresses.
            for i in range(0, len(suspiciousIP), 100):
                susIP = suspiciousIP[i:i + 100]
                formatedReq1 = req1.format(genericLabel=genericLabel,
                                           endpointType=endpointType,
                                           table=table,
                                           suspiciousIP=sqlList(susIP),
                                           dates=datesList)
                cursor.execute(formatedReq1)
                res = cursor.fetchall()

                if verbose:
                    sys.stdout.write("Register suspicious IPs...\n")

                for ip, avgpkt, stdpkt, avgbyt, stdbyt in res:
                    currData = data.loc[ip]

                    if genericLabel == "dstip":
                        dstip = ip
                        srcip = ""
                    else:
                        dstip = ""
                        srcip = ip

                    try:
                        if currData["nbpkt"] > avgpkt + 3 * stdpkt or currData[
                                "nbbyt"] > avgbyt + 3 * stdbyt:
                            outputFile.write(
                                "%s\t%s\t%s\t%s\t%s\t\n" %
                                (srcip, dstip, currData["nbpkt"],
                                 currData["nbbyt"],
                                 confidence(currData["nbpkt"], avgpkt, stdpkt,
                                            currData["nbbyt"], avgbyt,
                                            stdbyt)))
                    except TypeError:
                        # The statistics could not be compared (no past
                        # data): still report the endpoint.
                        if verbose:
                            sys.stdout.write(
                                "!!Warning!! no past data for %s (avgpkt=%s, stdpkt=%s, avgbyt=%s, stdbyt=%s)\n"
                                % (ip, avgpkt, stdpkt, avgbyt, stdbyt))
                        outputFile.write("%s\t%s\t%s\t%s\t%s\t\n" %
                                         (srcip, dstip, currData["nbpkt"],
                                          currData["nbbyt"], "MED"))

    # Store results in Hive
    try:
        transport = TTransport.TBufferedTransport(
            TSocket.TSocket('localhost', 10000))
        protocol = TBinaryProtocol.TBinaryProtocol(transport)
        client = ThriftHive.Client(protocol)
        transport.open()
        try:
            client.execute(
                "create table if not exists suspiciousheavyhitters (srcip string, dstip string, pkt bigint, byte bigint, confidence string) partitioned by(dt string, dataSrc string) row format delimited fields terminated by '\t'"
            )
            client.execute(
                "load data local inpath '{dir}/suspiciousheavyhitters_{table}_{date}.txt' overwrite into table suspiciousheavyhitters partition (dt='{date}', dataSrc='{table}')"
                .format(table=table, date=date, dir=outputDirectory))
        finally:
            transport.close()

    except Thrift.TException as tx:
        sys.stderr.write('%s\n' % (tx.message))
def connect_thrift_node(self, node_to_connect):
    """ attempt to connect a node through thrift networking services

    Nothing is attempted when the node is already flagged as connected or
    is already present in ``self.connections``.  Otherwise a Thrift
    connection is opened and this node registers itself with a freshly
    generated pass phrase.  On success the transport, client and pass
    phrase are stored on ``node_to_connect``; on rejection or error the
    connection attempts of the node are incremented and the transport is
    closed.

    :param node_to_connect: node object to connect to
    :return: result of the registration, False if no connection was made
    """
    connection_successful = False
    if not node_to_connect.connected:
        transport = None
        try:
            if node_to_connect not in self.connections:
                logger().info('attempting connect_thrift_node %s:%s',
                              node_to_connect.host, node_to_connect.port)
                pass_phrase = str(uuid.uuid4())

                # Buffering is critical. Raw sockets are very slow
                transport = TTransport.TBufferedTransport(
                    TSocket.TSocket(node_to_connect.host,
                                    int(node_to_connect.port)))

                # Wrap in a protocol
                protocol = TBinaryProtocol.TBinaryProtocol(transport)

                # Create a client to use the protocol encoder
                client = BlockchainService.Client(protocol)

                # Connect
                transport.open()
                logger().info('about to register')
                connection_successful = client.register_node(
                    self.this_node, pass_phrase)
                logger().info('transport open to node %s',
                              node_to_connect.node_id)
                if connection_successful:
                    node_to_connect.connected = True
                    node_to_connect.pass_phrase = pass_phrase
                    node_to_connect.transport = transport
                    node_to_connect.client = client
                    logger().info('%s accepted outbound connection request.',
                                  node_to_connect.node_id)
                    logger().info('node owner: %s', node_to_connect.owner)
                    logger().info('phases provided: %s',
                                  '{:05b}'.format(node_to_connect.phases))
                else:
                    try:
                        # incrementing connection attempts on fail
                        net_dao.update_con_attempts(node_to_connect)
                    except Exception as ex:
                        template = "An exception of type {0} occured. Arguments:\n{1!r}"
                        message = template.format(type(ex).__name__, ex.args)
                        logger().warning(message)
                    # `transport.close` without parentheses only looked the
                    # method up and left the socket open.
                    transport.close()
                    print(node_to_connect.node_id +
                          ' rejected outbound connection request.')

        except Exception as ex:
            if not connection_successful:
                # The transport is not handed over to the node in this
                # case, so release it here.
                if transport is not None:
                    transport.close()
                net_dao.update_con_attempts(node_to_connect)
            template = "An exception of type {0} occured. Arguments:\n{1!r}"
            message = template.format(type(ex).__name__, ex.args)
            logger().warning(message)
        finally:
            logger().info('connect_thrift_node %s', str(connection_successful))
    return connection_successful
import configparser
import os
import os.path
import sys

# Read in configs from config file
if len(sys.argv) < 2:
    sys.exit("usage: %s <config-file>" % sys.argv[0])
conffile = sys.argv[1]
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the
# list of files it did parse; fail here with a clear message instead of a
# KeyError on 'Host' below.
if not config.read(conffile):
    sys.exit("cannot read config file: %s" % conffile)
host = config['DEFAULT']['Host']
# NOTE(review): the port stays a string, exactly as before - confirm that is
# what the rest of the script expects.
port = config['DEFAULT']['Port']

# Connect to HBase Thrift server: This creates the socket transport and line protocol and
# allows the Thrift client to connect and talk to the Thrift server.
transport = TTransport.TBufferedTransport(TSocket.TSocket(host, port))
protocol = TBinaryProtocol.TBinaryProtocolAccelerated(transport)

# Create and open the client connection: create the Client object you will be using to
# interact with HBase. From this client object, you will issue all your Gets and Puts.
client = Hbase.Client(protocol)
transport.open()

#########################################
# Define Table Name and Column Family Name Here
#########################################
tablename = 'haztable5'
cfname = 'cf1'
tableexists = 0
def do_POST(self):
    """
    Handles POST queries, which are usually Thrift messages.

    The request path is split into a product endpoint, an API version and
    a request endpoint. A Thrift processor for the matching handler is
    created and run against the request body, and the serialized result is
    sent back with HTTP status 200.

    When authentication is enabled, requests without a valid session are
    refused with HTTP 401 unless they target the '/Authentication'
    endpoint. Any exception raised while routing or processing is
    converted to a Thrift ``TApplicationException`` and returned with HTTP
    status 200 so that JSON Thrift clients can parse it.
    """
    client_host, client_port = self.client_address
    auth_session = self.__check_session_cookie()
    LOG.debug("%s:%s -- [%s] POST %s", client_host, str(client_port),
              auth_session.user if auth_session else "Anonymous",
              self.path)

    # Create new thrift handler.
    checker_md_docs = self.server.checker_md_docs
    checker_md_docs_map = self.server.checker_md_docs_map
    suppress_handler = self.server.suppress_handler
    version = self.server.version

    protocol_factory = TJSONProtocol.TJSONProtocolFactory()
    input_protocol_factory = protocol_factory
    output_protocol_factory = protocol_factory

    itrans = TTransport.TFileObjectTransport(self.rfile)
    itrans = TTransport.TBufferedTransport(
        itrans, int(self.headers['Content-Length']))
    otrans = TTransport.TMemoryBuffer()

    iprot = input_protocol_factory.getProtocol(itrans)
    oprot = output_protocol_factory.getProtocol(otrans)

    def send_thrift_result():
        # Ship whatever has been serialized into the output buffer.
        result = otrans.getvalue()
        self.send_response(200)
        self.send_header("content-type", "application/x-thrift")
        self.send_header("Content-Length", len(result))
        self.end_headers()
        self.wfile.write(result)

    if self.server.manager.is_enabled and \
            not self.path.endswith('/Authentication') and \
            not auth_session:
        # Bail out if the user is not authenticated...
        # This response has the possibility of melting down Thrift clients,
        # but the user is expected to properly authenticate first.
        LOG.debug(client_host + ":" + str(client_port) +
                  " Invalid access, credentials not found " +
                  "- session refused.")
        self.send_error(401)
        return

    # Authentication is handled, we may now respond to the user.
    try:
        product_endpoint, api_ver, request_endpoint = \
            routing.split_client_POST_request(self.path)

        product = None
        if product_endpoint:
            # The current request came through a product route, and not
            # to the main endpoint.
            product = self.server.get_product(product_endpoint)
            self.__check_prod_db(product)

        version_supported = routing.is_supported_version(api_ver)
        if version_supported:
            major_version, _ = version_supported

            if major_version == 6:
                if request_endpoint == 'Authentication':
                    auth_handler = AuthHandler_v6(
                        self.server.manager,
                        auth_session,
                        self.server.config_session)
                    processor = AuthAPI_v6.Processor(auth_handler)
                elif request_endpoint == 'Products':
                    prod_handler = ProductHandler_v6(
                        self.server,
                        auth_session,
                        self.server.config_session,
                        product,
                        version)
                    processor = ProductAPI_v6.Processor(prod_handler)
                elif request_endpoint == 'CodeCheckerService':
                    # This endpoint is a product's report_server.
                    if not product:
                        error_msg = "Requested CodeCheckerService on a " \
                                    "nonexistent product: '{0}'." \
                                    .format(product_endpoint)
                        LOG.error(error_msg)
                        raise ValueError(error_msg)

                    if product_endpoint:
                        # The current request came through a
                        # product route, and not to the main endpoint.
                        product = self.server.get_product(product_endpoint)
                        self.__check_prod_db(product)

                    acc_handler = ReportHandler_v6(
                        self.server.manager,
                        product.session_factory,
                        product,
                        auth_session,
                        self.server.config_session,
                        checker_md_docs,
                        checker_md_docs_map,
                        suppress_handler,
                        version)
                    processor = ReportAPI_v6.Processor(acc_handler)
                else:
                    LOG.debug("This API endpoint does not exist.")
                    error_msg = "No API endpoint named '{0}'." \
                                .format(self.path)
                    raise ValueError(error_msg)
        else:
            if request_endpoint == 'Authentication':
                # API-version checking is supported on the auth endpoint.
                handler = BadAPIHandler(api_ver)
                processor = AuthAPI_v6.Processor(handler)
            else:
                # Send a custom, but valid Thrift error message to the
                # client requesting this action.
                error_msg = "Incompatible client/server API. " \
                            "API versions supported by this server {0}." \
                            .format(get_version_str())
                raise ValueError(error_msg)

        processor.process(iprot, oprot)
        send_thrift_result()
        return

    except Exception as exn:
        # Convert every Exception to the proper format which can be parsed
        # by the Thrift clients expecting JSON responses.
        # str(exn) is used instead of exn.message: the latter does not
        # exist on Python 3 exceptions and would mask the real error.
        error_text = str(exn)
        LOG.error(error_text)
        import traceback
        traceback.print_exc()
        ex = TApplicationException(TApplicationException.INTERNAL_ERROR,
                                   error_text)
        # NOTE(review): this reads the message header from the request
        # stream; it presumably relies on the header not having been
        # consumed yet - confirm for failures raised inside
        # processor.process().
        fname, _, seqid = iprot.readMessageBegin()
        oprot.writeMessageBegin(fname, TMessageType.EXCEPTION, seqid)
        ex.write(oprot)
        oprot.writeMessageEnd()
        oprot.trans.flush()
        send_thrift_result()
        return
def __init__(
        self,
        uri=None,
        user=None,
        password=None,
        host=None,
        port=6274,
        dbname=None,
        protocol='binary',
        sessionid=None,
        bin_cert_validate=None,
        bin_ca_certs=None,
):
    """Open a Thrift connection to the database server.

    Connection details are given either as a single ``uri`` or as the
    individual ``host``/``port``/... arguments, never both. An existing
    ``sessionid`` may be supplied instead of credentials, in which case it
    is validated by calling ``self.get_tables()``.

    :param uri: connection URI, parsed by ``_parse_uri``; mutually
        exclusive with every non-default individual argument.
    :param user: user name passed to the server's ``connect`` call.
    :param password: password passed to the server's ``connect`` call.
    :param host: server host; required unless provided through ``uri``.
    :param port: server port (default 6274).
    :param dbname: database name passed to the server's ``connect`` call.
    :param protocol: one of ``'binary'``, ``'http'`` or ``'https'``.
    :param sessionid: existing session id to reuse; cannot be combined
        with ``user``, ``password``, ``uri`` or ``dbname``.
    :param bin_cert_validate: forwarded as ``validate`` to the SSL socket;
        only allowed with the binary protocol.
    :param bin_ca_certs: forwarded as ``ca_certs`` to the SSL socket; only
        allowed with the binary protocol.
    :raises TypeError: on conflicting or missing arguments.
    :raises ValueError: on an unknown ``protocol``, or when the transport
        fails while establishing the session.
    :raises OperationalError: when the transport cannot be opened.
    :raises RuntimeError: when the server reports a version below 4.6.
    """
    self.sessionid = None
    if sessionid is not None:
        if any([user, password, uri, dbname]):
            raise TypeError("Cannot specify sessionid with user, password,"
                            " dbname, or uri")
    if uri is not None:
        if not all([
                user is None,
                password is None,
                host is None,
                port == 6274,
                dbname is None,
                protocol == 'binary',
                bin_cert_validate is None,
                bin_ca_certs is None,
        ]):
            raise TypeError("Cannot specify both URI and other arguments")
        user, password, host, port, dbname, protocol, \
            bin_cert_validate, bin_ca_certs = _parse_uri(uri)
    if host is None:
        raise TypeError("`host` parameter is required.")
    if protocol != 'binary' and not all(
            [bin_cert_validate is None, bin_ca_certs is None]):
        raise TypeError("Cannot specify bin_cert_validate or bin_ca_certs,"
                        " without binary protocol")

    if protocol in ("http", "https"):
        if not host.startswith(protocol):
            # the THttpClient expects http[s]://localhost
            host = '{0}://{1}'.format(protocol, host)
        transport = THttpClient.THttpClient("{}:{}".format(host, port))
        proto = TJSONProtocol.TJSONProtocol(transport)
        socket = None
    elif protocol == "binary":
        if any([bin_cert_validate is not None, bin_ca_certs]):
            socket = TSSLSocket.TSSLSocket(host, port,
                                           validate=(bin_cert_validate),
                                           ca_certs=bin_ca_certs)
        else:
            socket = TSocket.TSocket(host, port)
        transport = TTransport.TBufferedTransport(socket)
        proto = TBinaryProtocol.TBinaryProtocolAccelerated(transport)
    else:
        # Build ONE message string; passing the pieces as separate
        # arguments made the error render as a tuple.
        raise ValueError("`protocol` should be one of"
                         " ['http', 'https', 'binary'],"
                         " got {} instead".format(protocol))

    self._user = user
    self._password = password
    self._host = host
    self._port = port
    self._dbname = dbname
    self._transport = transport
    self._protocol = protocol
    self._socket = socket
    self._closed = 0
    self._tdf = None
    self._rbc = None

    try:
        self._transport.open()
    except TTransportException as e:
        # NOTE(review): `e.NOT_OPEN` looks like a class-level constant, so
        # this test is presumably always true and the `else` branch
        # unreachable; `e.type == TTransportException.NOT_OPEN` may have
        # been intended. Left unchanged so callers keep seeing
        # OperationalError - confirm before tightening.
        if e.NOT_OPEN:
            err = OperationalError("Could not connect to database")
            raise err from e
        else:
            raise

    self._client = Client(proto)
    try:
        # If a sessionid was passed, we should validate it
        if sessionid:
            self._session = sessionid
            self.get_tables()
            self.sessionid = sessionid
        else:
            self._session = self._client.connect(user, password, dbname)
    except TMapDException as e:
        raise _translate_exception(e) from e
    except TTransportException:
        raise ValueError(f"Connection failed with port {port} and "
                         f"protocol '{protocol}'. Try port 6274 for "
                         "protocol == binary or 6273, 6278 or 443 for "
                         "http[s]")

    # if OmniSci version <4.6, raise RuntimeError, as data import can be
    # incorrect for columnar date loads
    # Caused by https://github.com/omnisci/pymapd/pull/188
    semver = self._client.get_version()
    if Version(semver.split("-")[0]) < Version("4.6"):
        raise RuntimeError(f"Version {semver} of OmniSci detected. "
                           "Please use pymapd <0.11. See release notes "
                           "for more details.")
class thrift_utils:
    """Namespace holding a shared Thrift transport and binary protocol.

    The buffered transport to 192.168.56.1:9090 is created and opened while
    the class body executes, i.e. when this class definition is evaluated.
    """

    # Buffered transport wrapped around the raw socket.
    transport = TTransport.TBufferedTransport(
        TSocket.TSocket('192.168.56.1', 9090))
    # Binary protocol layered on top of the buffered transport.
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    transport.open()
import sys

# The Python files generated from Hbase.thrift are placed here.
sys.path.append('/usr/local/python2.7.3/lib/python2.7/site-packages/hbase')

from thrift import Thrift
from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol
from hbase import Hbase
# Types such as ColumnDescriptor are defined in hbase.ttypes.
from hbase.ttypes import *

# Make the socket (the address and port can be changed here) and wrap it in a
# buffered transport. Buffering is critical: raw sockets are very slow.
# TFramedTransport can be used as well; it is also an efficient transport.
transport = TTransport.TBufferedTransport(
    TSocket.TSocket('192.168.1.220', 9090))

# Wrap in a protocol. The protocol is separate from the transport, so several
# protocols can be supported.
protocol = TBinaryProtocol.TBinaryProtocol(transport)

# The client represents one user.
client = Hbase.Client(protocol)

# Open the connection.
transport.open()

# Print the table names.
print(client.getTableNames())
def main():
    """Upload a farm mission to the drone and monitor it until it disarms.

    Opens Thrift connections to the drone service and to the base-door
    service on localhost, opens the doors, replaces the drone's missions
    with one built from the module-level ``coordinate_list`` ("lat,lng"
    strings) at ``args.altitude``, and then polls the flight status every
    three seconds. The camera is requested whenever the reported position
    is within 0.0001 degrees of the first waypoint. Once the drone reports
    it is no longer armed, the mission and drone REST resources are patched
    to "Done"/"Available", the mode is switched to "RTL" and polling stops.

    Relies on module-level names defined elsewhere in the file: ``port``,
    ``base_port``, ``args``, ``coordinate_list``, ``mission_endpoint`` and
    ``drone_endpoint``.
    """
    # Make sockets. Buffering is critical: raw sockets are very slow.
    transport = TTransport.TBufferedTransport(
        TSocket.TSocket('localhost', port))
    base_transport = TTransport.TBufferedTransport(
        TSocket.TSocket('localhost', base_port))

    # Wrap in a protocol and create clients to use the protocol encoder.
    client = Drone.Client(TBinaryProtocol.TBinaryProtocol(transport))
    base_client = BaseDoor.Client(
        TBinaryProtocol.TBinaryProtocol(base_transport))

    # Stays None when coordinate_list is empty, so the proximity check
    # below is skipped instead of failing on an unbound name.
    first_coordinate = None

    try:
        # Connect!
        transport.open()
        base_transport.open()
        time.sleep(15)

        # Prepare doors!
        print("Opening doors...")
        base_client.openDoor()

        print("Clearing existing missions...")
        client.clear_missions()

        print("Downloading missions...")
        client.download_missions()

        print("Creating coordinate objects...")
        thrift_coordinate_list = []
        altitude = float(args.altitude)
        for count, coordinate in enumerate(coordinate_list):
            latlng = coordinate.split(",")
            lat = float(latlng[0])
            lng = float(latlng[1])
            if count == 0:
                first_coordinate = {"latitude": lat, "longitude": lng}
                print("First waypoint is {0},{1}".format(
                    first_coordinate["latitude"],
                    first_coordinate["longitude"]))
            thrift_coordinate_list.append(
                Coordinate(latitude=lat, longitude=lng, altitude=altitude))

        print("Sending coordinates to server...")
        client.add_farm_mission(thrift_coordinate_list)
    finally:
        # Release both sockets even when the mission upload fails.
        transport.close()
        base_transport.close()

    tolerance = 0.0001
    print("Starting in flight status reports...")
    while True:
        transport.open()
        try:
            print("Reporting flight status...")
            status_obj = client.report_status(int(args.drone_id))
            print("Armed: {0}".format(status_obj.armed))

            # Compare latitude with latitude and longitude with longitude
            # (the longitude bounds used to be checked against the
            # reported latitude).
            if first_coordinate is not None and \
                    first_coordinate["latitude"] - tolerance \
                    <= status_obj.latitude \
                    <= first_coordinate["latitude"] + tolerance and \
                    first_coordinate["longitude"] - tolerance \
                    <= status_obj.longitude \
                    <= first_coordinate["longitude"] + tolerance:
                print("Requesting camera start...")
                client.start_camera()

            if not status_obj.armed:
                end_mission_status_data = {"status": "Done"}
                end_drone_status_data = {"status": "Available"}
                requests.patch(mission_endpoint,
                               data=end_mission_status_data)
                requests.patch(drone_endpoint,
                               data=end_drone_status_data)
                client.change_mode("RTL")
                break
        finally:
            # Close! Runs on normal iterations, on break and on errors.
            transport.close()
        time.sleep(3)