Example #1
0
    def do_POST(self):
        """
        Handles POST queries, which are usually Thrift messages.

        The body is decoded with the Thrift JSON protocol. The message header
        is read once up front only so the called API function can be logged;
        the buffered body is then re-wrapped in an in-memory transport for the
        processor that actually serves the call.

        When the session manager is enabled, requests to any endpoint other
        than Authentication, Configuration and ServerInfo are refused unless
        a valid session cookie is present.

        The request path is split by ``routing.split_client_POST_request``
        into (product endpoint, API version, service name) and dispatched to
        the matching API v6 handler. Every error raised while routing or
        processing is logged and reported to the client through
        ``send_thrift_exception``.
        """
        protocol_factory = TJSONProtocol.TJSONProtocolFactory()
        input_protocol_factory = protocol_factory
        output_protocol_factory = protocol_factory

        # Get Thrift API function name to print to the log output.
        itrans = TTransport.TFileObjectTransport(self.rfile)
        itrans = TTransport.TBufferedTransport(
            itrans, int(self.headers['Content-Length']))
        iprot = input_protocol_factory.getProtocol(itrans)
        fname, _, _ = iprot.readMessageBegin()

        client_host, client_port, is_ipv6 = \
            RequestHandler._get_client_host_port(self.client_address)
        self.auth_session = self.__check_session_cookie()
        LOG.info("%s:%s -- [%s] POST %s@%s",
                 client_host if not is_ipv6 else '[' + client_host + ']',
                 client_port,
                 self.auth_session.user if self.auth_session else "Anonymous",
                 self.path, fname)

        # Create new thrift handler.
        checker_md_docs = self.server.checker_md_docs
        checker_md_docs_map = self.server.checker_md_docs_map
        version = self.server.version

        # readMessageBegin() above already consumed part of the message, so
        # rebuild a transport over the buffered bytes to let the processor
        # read the request from the start.
        cstringio_buf = itrans.cstringio_buf.getvalue()
        itrans = TTransport.TMemoryBuffer(cstringio_buf)
        iprot = input_protocol_factory.getProtocol(itrans)

        otrans = TTransport.TMemoryBuffer()
        oprot = output_protocol_factory.getProtocol(otrans)

        if self.server.manager.is_enabled and \
                not self.path.endswith(('/Authentication',
                                        '/Configuration',
                                        '/ServerInfo')) and \
                not self.auth_session:
            # Bail out if the user is not authenticated...
            # This response has the possibility of melting down Thrift clients,
            # but the user is expected to properly authenticate first.
            LOG.debug(
                "%s:%s Invalid access, credentials not found "
                "- session refused.",
                client_host if not is_ipv6 else '[' + client_host + ']',
                str(client_port))

            self.send_thrift_exception("Error code 401: Unauthorized!", iprot,
                                       oprot, otrans)
            return

        # Authentication is handled, we may now respond to the user.
        try:
            product_endpoint, api_ver, request_endpoint = \
                routing.split_client_POST_request(self.path)
            if product_endpoint is None and api_ver is None and\
                    request_endpoint is None:
                raise Exception("Invalid request endpoint path.")

            product = None
            if product_endpoint:
                # The current request came through a product route, and not
                # to the main endpoint.
                product = self.__check_prod_db(product_endpoint)

            version_supported = routing.is_supported_version(api_ver)
            if not version_supported:
                error_msg = "The API version you are using is not supported " \
                            "by this server (server API version: {0})!".format(
                                get_version_str())
                self.send_thrift_exception(error_msg, iprot, oprot, otrans)
                return

            major_version, _ = version_supported
            if major_version != 6:
                # Only v6 handlers exist below; without this check `processor`
                # would be unbound and the client would only receive an
                # UnboundLocalError message.
                raise ValueError(
                    "API major version {0} is not served by this "
                    "server.".format(major_version))

            if request_endpoint == 'Authentication':
                auth_handler = AuthHandler_v6(
                    self.server.manager, self.auth_session,
                    self.server.config_session)
                processor = AuthAPI_v6.Processor(auth_handler)
            elif request_endpoint == 'Configuration':
                conf_handler = ConfigHandler_v6(
                    self.auth_session, self.server.config_session)
                processor = ConfigAPI_v6.Processor(conf_handler)
            elif request_endpoint == 'ServerInfo':
                server_info_handler = ServerInfoHandler_v6(version)
                processor = ServerInfoAPI_v6.Processor(server_info_handler)
            elif request_endpoint == 'Products':
                prod_handler = ProductHandler_v6(
                    self.server, self.auth_session,
                    self.server.config_session, product, version)
                processor = ProductAPI_v6.Processor(prod_handler)
            elif request_endpoint == 'CodeCheckerService':
                # This endpoint is a product's report_server. `product` was
                # already resolved from `product_endpoint` above, so there is
                # no need to query the product database a second time.
                if not product:
                    error_msg = "Requested CodeCheckerService on a " \
                                "nonexistent product: '{0}'." \
                        .format(product_endpoint)
                    LOG.error(error_msg)
                    raise ValueError(error_msg)

                acc_handler = ReportHandler_v6(
                    self.server.manager, product.session_factory,
                    product, self.auth_session,
                    self.server.config_session, checker_md_docs,
                    checker_md_docs_map, version, self.server.context)
                processor = ReportAPI_v6.Processor(acc_handler)
            else:
                LOG.debug("This API endpoint does not exist.")
                error_msg = "No API endpoint named '{0}'." \
                            .format(self.path)
                raise ValueError(error_msg)

            processor.process(iprot, oprot)
            result = otrans.getvalue()

            self.send_response(200)
            self.send_header("content-type", "application/x-thrift")
            self.send_header("Content-Length", len(result))
            self.end_headers()
            self.wfile.write(result)
            return

        except Exception as exn:
            LOG.warning(str(exn))
            import traceback
            traceback.print_exc()

            # Re-wrap the request bytes so send_thrift_exception can read the
            # original message header (e.g. its sequence id) again.
            cstringio_buf = itrans.cstringio_buf.getvalue()
            if cstringio_buf:
                itrans = TTransport.TMemoryBuffer(cstringio_buf)
                iprot = input_protocol_factory.getProtocol(itrans)

            self.send_thrift_exception(str(exn), iprot, oprot, otrans)
            return
def pass_through_dfe(size, data_in):
    """PassThrough DFE implementation.

    Connects to a PassThroughService Thrift server on localhost:9090. It
    uploads ``data_in`` into a server-side float buffer, runs ``PassThrough``
    into a second buffer of ``size`` floats, and reads that buffer back. It
    then frees both buffers and closes the connection. The duration of every
    step is printed.

    :param size: number of floats to allocate and transfer.
    :param data_in: input stream sent to the server.
    :return: the output stream read back from the server.

    On a Thrift error the exception message is printed and the process exits
    with status -1; the connection is closed in every case.
    """
    transport = None
    try:
        start_time = time.time()

        # Make socket
        socket = TSocket.TSocket('localhost', 9090)

        # Buffering is critical. Raw sockets are very slow
        transport = TTransport.TBufferedTransport(socket)

        # Wrap in a protocol
        protocol = TBinaryProtocol.TBinaryProtocol(transport)

        # Create a client to use the protocol encoder
        client = PassThroughService.Client(protocol)

        print('Creating a client:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Connect!
        start_time = time.time()
        transport.open()
        print('Opening connection:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Allocate and send input streams to server
        start_time = time.time()
        address_data_in = client.malloc_float(size)
        client.send_data_float(address_data_in, data_in)
        print('Sending input data:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Allocate memory for output stream on server
        start_time = time.time()
        address_data_out = client.malloc_float(size)
        print('Allocating memory for output stream on server:\t%.5lfs' %
              (time.time() - start_time))

        # Action default
        start_time = time.time()
        client.PassThrough(size, address_data_in, address_data_out)
        print('Pass through time:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Get output stream from server
        start_time = time.time()
        data_out = client.receive_data_float(address_data_out, size)
        print('Getting output stream:\t(size = %d bit)\t%.5lfs' %
              ((size * 32), (time.time() - start_time)))

        # Free allocated memory for streams on server
        start_time = time.time()
        client.free(address_data_in)
        client.free(address_data_out)
        print('Freeing allocated memory for streams on server:\t%.5lfs' %
              (time.time() - start_time))

        # Close!
        start_time = time.time()
        transport.close()
        print('Closing connection:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        return data_out

    except Thrift.TException as thrift_exception:
        print('%s' % (thrift_exception.message))
        sys.exit(-1)

    finally:
        # Release the socket if a call failed midway; on the success path it
        # was already closed (and timed) above.
        if transport is not None and transport.isOpen():
            transport.close()
Example #3
0
def main(cfg, reqhandle, resphandle):
    """Run one ``Example`` service call and optionally dump the raw bytes.

    :param cfg: parsed options. The attributes read are ``unix``, ``addr``,
        ``transport``, ``headers``, ``protocol``, ``service``, ``method``,
        ``params``, ``request`` and ``response``.
    :param reqhandle: buffer handed to ``TRecordingTransport`` for the request
        side (presumably a ``BytesIO``; it must support ``getvalue()``).
    :param resphandle: same as ``reqhandle`` for the response side.

    Invalid options terminate the process through ``sys.exit``. When
    ``cfg.request`` / ``cfg.response`` is None, the recorded bytes are printed
    as space-separated upper-case hex pairs. The transport is always closed.
    """
    def hexdump(data):
        # bytearray() yields ints on Python 2 (str) and Python 3 (bytes);
        # ord() on the items of a Python 3 bytes object raises TypeError.
        return " ".join("%02X" % b for b in bytearray(data))

    if cfg.unix:
        if cfg.addr == "":
            sys.exit("invalid unix domain socket: {}".format(cfg.addr))
        socket = TSocket.TSocket(unix_socket=cfg.addr)
    else:
        try:
            (host, port) = cfg.addr.rsplit(":", 1)
            if host == "":
                host = "localhost"
            socket = TSocket.TSocket(host=host, port=int(port))
        except ValueError:
            sys.exit("invalid address: {}".format(cfg.addr))

    transport = TRecordingTransport(socket, reqhandle, resphandle)

    if cfg.transport == "framed":
        transport = TTransport.TFramedTransport(transport)
    elif cfg.transport == "unframed":
        transport = TTransport.TBufferedTransport(transport)
    elif cfg.transport == "header":
        transport = THeaderTransport.THeaderTransport(
            transport,
            client_type=THeaderTransport.CLIENT_TYPE.HEADER,
        )

        if cfg.headers is not None:
            for pair in cfg.headers.split(","):
                # Split on the first "=" only so values may contain "=".
                key, value = pair.split("=", 1)
                transport.set_header(key, value)

        if cfg.protocol == "binary":
            transport.set_protocol_id(THeaderTransport.T_BINARY_PROTOCOL)
        elif cfg.protocol == "compact":
            transport.set_protocol_id(THeaderTransport.T_COMPACT_PROTOCOL)
        else:
            sys.exit("header transport cannot be used with protocol {0}"
                     .format(cfg.protocol))
    else:
        sys.exit("unknown transport {0}".format(cfg.transport))

    transport.open()
    try:
        if cfg.protocol == "binary":
            protocol = TBinaryProtocol.TBinaryProtocol(transport)
        elif cfg.protocol == "compact":
            protocol = TCompactProtocol.TCompactProtocol(transport)
        elif cfg.protocol == "json":
            protocol = TJSONProtocol.TJSONProtocol(transport)
        elif cfg.protocol == "finagle":
            protocol = TFinagleProtocol(transport,
                                        client_id="thrift-playground")
        else:
            sys.exit("unknown protocol {0}".format(cfg.protocol))

        if cfg.service is not None:
            protocol = TMultiplexedProtocol.TMultiplexedProtocol(
                protocol, cfg.service)

        client = Example.Client(protocol)

        try:
            if cfg.method == "ping":
                client.ping()
                print("client: pinged")
            elif cfg.method == "poke":
                client.poke()
                print("client: poked")
            elif cfg.method == "add":
                if len(cfg.params) != 2:
                    sys.exit("add takes 2 arguments, got: {0}"
                             .format(cfg.params))

                a = int(cfg.params[0])
                b = int(cfg.params[1])
                v = client.add(a, b)
                print("client: added {0} + {1} = {2}".format(a, b, v))
            elif cfg.method == "execute":
                param = Param(
                    return_fields=cfg.params,
                    the_works=TheWorks(
                        field_1=True,
                        field_2=0x7f,
                        field_3=0x7fff,
                        field_4=0x7fffffff,
                        field_5=0x7fffffffffffffff,
                        field_6=-1.5,
                        field_7=u"string is UTF-8: \U0001f60e",
                        field_8=b"binary is bytes: \x80\x7f\x00\x01",
                        field_9={
                            1: "one",
                            2: "two",
                            3: "three"
                        },
                        field_10=[1, 2, 4, 8],
                        field_11=set(["a", "b", "c"]),
                        field_12=False,
                    ))

                try:
                    result = client.execute(param)
                    print("client: executed {0}: {1}".format(param, result))
                except AppException as e:
                    print("client: execute failed with IDL Exception: {0}"
                          .format(e.why))
            else:
                sys.exit("unknown method {0}".format(cfg.method))
        except Thrift.TApplicationException as e:
            print("client exception: {0}: {1}".format(e.type, e.message))

        if cfg.request is None:
            print("request: {}".format(hexdump(reqhandle.getvalue())))
        if cfg.response is None:
            print("response: {}".format(hexdump(resphandle.getvalue())))
    finally:
        # Also runs when sys.exit() or an unexpected exception fires above.
        transport.close()
Example #4
0
def _shorten_path(path, limit=50):
    """Abbreviate a file path longer than ``limit`` characters for display.

    When the path has at least three separators, keep its last three
    components behind a ``...\\`` (backslash-separated) or ``.../``
    (slash-separated) prefix; backslashes are checked first. Otherwise keep
    the last ``limit`` characters behind ``...``. Exactly one string is
    always returned, so every result row has the same number of columns.
    """
    if len(path) <= limit:
        return path
    for sep in ('\\', '/'):
        # rsplit(sep, 3) only yields the 4 parts indexed below when there
        # are at least 3 separators; fewer would raise IndexError.
        if path.count(sep) >= 3:
            return '...' + sep + sep.join(path.rsplit(sep, 3)[1:])
    return '...' + path[-limit:]


def getresultbyid(resid):
    """Fetch a comparison result from the sdhash server as display rows.

    Uses the module-level ``host`` and ``port`` to reach the server.

    :param resid: result identifier passed to ``displayResult`` and
        ``displayResultInfo``.
    :return: a list of ``[queryset, query, targetset, target, score]`` rows,
        one per ``query|target|score`` line of the result text. Paths longer
        than 50 characters are abbreviated.
    """
    socket = TSocket.TSocket(host, port)
    transport = TTransport.TBufferedTransport(socket)
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    client = sdhashsrv.Client(protocol)
    transport.open()
    try:
        result = client.displayResult(resid)
        info = client.displayResultInfo(resid)
    finally:
        transport.close()

    stuff = info.split()
    queryset = stuff[0]
    # Presumably info reads "<query> -- <target>" for a cross-set comparison;
    # otherwise the set was compared against itself.
    targetset = stuff[2] if info.count('--') == 1 else queryset

    output = []
    for line in result.split('\n'):
        cols = line.split('|')
        if len(cols) != 3:
            continue
        output.append([queryset, _shorten_path(cols[0]),
                       targetset, _shorten_path(cols[1]),
                       cols[2]])
    return output
Example #5
0
from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol

from h2oai_scoring import ScoringService
from h2oai_scoring.ttypes import Row

# -----------------------------------------------
# Name           Type      Range
# -----------------------------------------------
# L3_S30_D3521   float64   [0.4, 1718.39] or None
# -----------------------------------------------

# Connect to the H2O scoring Thrift server on localhost:9090.
socket = TSocket.TSocket('localhost', 9090)
transport = TTransport.TBufferedTransport(socket)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
client = ScoringService.Client(protocol)
transport.open()

server_hash = client.get_hash()
# str.format needs a placeholder; without "{}" the argument is silently
# dropped and the hash is never shown.
print('Scoring server hash: {}'.format(server_hash))

print('Scoring individual rows...')
row1 = Row()
row1.l3S30D3521 = 1660.7575097837375  # L3_S30_D3521

row2 = Row()
row2.l3S30D3521 = 757.5741980572641  # L3_S30_D3521

row3 = Row()
Example #6
0
def handler(event, context):
    """Run a synthetic parameter-server training benchmark for one worker.

    ``(event, context)`` is a Lambda-style entry point; ``context`` is
    unused. ``event`` must provide ``rank``, ``num_workers``, ``host``,
    ``port`` and ``size``. The worker registers a model of ``size`` values,
    and then runs ``NUM_EPOCHS`` x ``NUM_BATCHES`` steps. Each step pulls the
    model, pushes a random gradient and waits for all workers. Per-step
    timings are printed. The Thrift connection is always closed.
    """
    start_time = time.time()
    worker_index = event['rank']
    num_workers = event['num_workers']
    host = event['host']
    port = event['port']
    size = event['size']

    print('number of workers = {}'.format(num_workers))
    print('worker index = {}'.format(worker_index))
    print("host = {}".format(host))
    print("port = {}".format(port))
    print("size = {}".format(size))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(host, port)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()
    try:
        # test thrift connection
        ps_client.ping(t_client)
        print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
            host, port))

        # register model
        ps_client.register_model(t_client, worker_index, MODEL_NAME, size,
                                 num_workers)
        ps_client.exist_model(t_client, MODEL_NAME)
        print("register and check model >>> name = {}, length = {}".format(
            MODEL_NAME, size))

        # Training the Model
        train_start = time.time()
        iter_counter = 0
        for epoch in range(NUM_EPOCHS):
            epoch_start = time.time()
            for batch_index in range(NUM_BATCHES):
                print("------worker {} epoch {} batch {}------".format(
                    worker_index, epoch, batch_index))
                batch_start = time.time()

                # No real model is trained; loss stays a placeholder.
                loss = 0.0

                # pull latest model (only timed; the gradient below is random)
                ps_client.can_pull(t_client, MODEL_NAME, iter_counter,
                                   worker_index)
                pull_start = time.time()
                ps_client.pull_model(t_client, MODEL_NAME, iter_counter,
                                     worker_index)
                pull_time = time.time() - pull_start

                w_b_grad = np.random.rand(1, size).astype(np.double).flatten()

                # push gradient to PS
                ps_client.can_push(t_client, MODEL_NAME, iter_counter,
                                   worker_index)
                push_start = time.time()
                ps_client.push_grad(t_client, MODEL_NAME, w_b_grad,
                                    LEARNING_RATE, iter_counter, worker_index)
                push_time = time.time() - push_start
                ps_client.can_pull(t_client, MODEL_NAME, iter_counter + 1,
                                   worker_index)  # sync all workers

                print(
                    'Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                    'batch cost %.4f s: pull model cost %.4f s, push update cost %.4f s'
                    % (epoch + 1, NUM_EPOCHS, batch_index + 1, NUM_BATCHES,
                       time.time() - train_start, loss,
                       time.time() - epoch_start, time.time() - batch_start,
                       pull_time, push_time))
                iter_counter += 1
    finally:
        transport.close()

    end_time = time.time()
    print("Elapsed time = {} s".format(end_time - start_time))
Example #7
0
    def connect(self, host=None, port=None, uri=None, timeout=10000):
        """
        Connect method should be called before any operations.
        Server will be connected after connect return OK

        The target is resolved as follows:

        - An explicit ``host`` is used with ``port`` (default 9090).
        - Otherwise the host and port come from ``uri``.
        - If ``uri`` is also omitted, they come from the
          ``config.THRIFTCLIENT_TRANSPORT`` URI.

        Only the ``tcp`` scheme is accepted. The transport wrappers (buffered,
        zlib, framed) and the protocol (binary, compact, json) are chosen
        from ``config``.

        :type  host: str
        :type  port: int or str
        :type  uri: str
        :type  timeout: int
        :param timeout: (Optional) socket timeout in milliseconds, passed to
            ``TSocket.setTimeout``; a falsy value keeps the socket default
        :param host: (Optional) host of the server; when omitted it is taken
            from ``uri`` or from the configured transport URI
        :param port: (Optional) port of the server, default port is 9090
        :param uri: (Optional) only support tcp proto now, e.g.

                `tcp://127.0.0.1:9090`

        :return: Status, indicate if connect is successful
        :rtype: Status
        :raises RepeatingConnectError: if this client is already connected
        :raises RuntimeError: for a non-tcp URI or an unknown protocol setting
        :raises NotConnectError: if opening the transport fails
        """
        if self.status and self.status == Status.SUCCESS:
            raise RepeatingConnectError("You have already connected!")

        config_uri = urlparse(config.THRIFTCLIENT_TRANSPORT)

        _uri = urlparse(uri) if uri else config_uri

        if not host:
            if _uri.scheme != 'tcp':
                if uri:
                    raise RuntimeError('Invalid parameter uri: {}'.format(uri))
                raise RuntimeError(
                    'Invalid configuration for THRIFTCLIENT_TRANSPORT: {transport}'
                    .format(transport=config.THRIFTCLIENT_TRANSPORT))
            host = _uri.hostname
            port = _uri.port or 9090
        else:
            port = port or 9090

        self._transport = TSocket.TSocket(host, port)

        if timeout:
            self._transport.setTimeout(int(timeout))

        if config.THRIFTCLIENT_BUFFERED:
            self._transport = TTransport.TBufferedTransport(self._transport)
        if config.THRIFTCLIENT_ZLIB:
            self._transport = TZlibTransport.TZlibTransport(self._transport)
        if config.THRIFTCLIENT_FRAMED:
            self._transport = TTransport.TFramedTransport(self._transport)

        if config.THRIFTCLIENT_PROTOCOL == Protocol.BINARY:
            protocol = TBinaryProtocol.TBinaryProtocol(self._transport)

        elif config.THRIFTCLIENT_PROTOCOL == Protocol.COMPACT:
            protocol = TCompactProtocol.TCompactProtocol(self._transport)

        elif config.THRIFTCLIENT_PROTOCOL == Protocol.JSON:
            protocol = TJSONProtocol.TJSONProtocol(self._transport)

        else:
            raise RuntimeError(
                "invalid configuration for THRIFTCLIENT_PROTOCOL: {protocol}".
                format(protocol=config.THRIFTCLIENT_PROTOCOL))

        self._client = MilvusService.Client(protocol)

        try:
            self._transport.open()
            self.status = Status(Status.SUCCESS, 'Connected')
            return self.status

        except TTransport.TTransportException as e:
            self.status = Status(code=e.type, message=e.message)
            LOGGER.error(e)
            # Chain the transport error so callers can see the real cause.
            raise NotConnectError('Connection failed') from e
 def get_client(self, host, port):
     """Build a KnnThriftService client for ``host``:``port``.

     The transport is returned unopened alongside the client; the caller
     opens it before making calls and closes it afterwards.

     :return: a ``(client, transport)`` tuple.
     """
     buffered = TTransport.TBufferedTransport(TSocket.TSocket(host, port))
     knn_client = KnnThriftService.Client(
         TBinaryProtocol.TBinaryProtocolAccelerated(buffered))
     return knn_client, buffered
Example #9
0

# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/1.10/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = config.get('default', 'SECRET_KEY')

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

ALLOWED_HOSTS = []

# A single buffered Thrift connection, built from the [thrift] ADDRESS/PORT
# config entries and opened as soon as this settings module is imported.
TRANSPORT = TTransport.TBufferedTransport(
    TSocket.TSocket(config.get('thrift', 'ADDRESS'),
                    config.get('thrift', 'PORT')))
PROTOCOL = TBinaryProtocol.TBinaryProtocol(TRANSPORT)
TRANSPORT.open()


# Application definition

INSTALLED_APPS = [
    'home.apps.HomeConfig',
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
Example #10
0
def main():
    """Command-line client for the key/value store Thrift service.

    Options (each takes one or more values):

    - ``-server host:port``: server address (default 127.0.0.1:9090).
    - ``-set key value``: store a value.
    - ``-get key outputFile``: fetch a value, print it and write it to
      ``outputFile``.
    - ``-del key``: delete a key.

    The set, get and del commands run in that order. The last result is
    passed to ``result_printer``. On a malformed option the usage text is
    written to stderr and the function returns early. The connection is
    closed on every path.
    """
    usage = ('Usage: KVStoreClient -server <host:port> '
             '-CMD_NAME <key w/ val w/ outputFile>')

    parser = argparse.ArgumentParser()
    parser.add_argument('-server', nargs='+')
    parser.add_argument('-set', nargs='+')
    parser.add_argument('-get', nargs='+')
    parser.add_argument('-del', nargs='+')
    args = parser.parse_args()

    # Make socket. argparse sets every declared option (None when absent),
    # so `'server' in args` was always true; test the value instead.
    if args.server is not None:
        host_port = args.server[0].split(':')
        if len(host_port) != 2:
            sys.stderr.write(usage)
            return
        transport = TSocket.TSocket(host_port[0], host_port[1])
    else:
        transport = TSocket.TSocket('127.0.0.1', 9090)

    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)

    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)

    # Create a client to use the protocol encoder
    client = Client(protocol)

    # Connect!
    transport.open()
    try:
        result = None

        key_value = getattr(args, 'set')
        if key_value is not None:
            if len(key_value) != 2:
                sys.stderr.write(usage)
                return
            result = set_kv(client, key_value[0], key_value[1])

        key_file = getattr(args, 'get')
        if key_file is not None:
            if len(key_file) != 2:
                sys.stderr.write(usage)
                return
            result = get_k(client, key_file[0])
            if result.error == 0:
                print(result.value)
                # `with` guarantees the file is flushed and closed (the
                # previous `f.close` lacked the call parentheses).
                with open(key_file[1], 'w') as out_file:
                    out_file.write(result.value)

        # "del" is a keyword, so the option must be read via getattr().
        key = getattr(args, 'del')
        if key is not None:
            if len(key) != 1:
                sys.stderr.write(usage)
                return
            result = del_k(client, key[0])

        result_printer(result)
    finally:
        # Close! (also on the early-return usage-error paths)
        transport.close()
Example #11
0
def handler(event, context):
    """Train one shard of a sparse SVM against the parameter server.

    ``(event, context)`` is a Lambda-style entry point; ``context`` is
    unused. ``event`` provides the following:

    - Hyper-parameters: ``num_features``, ``learning_rate``, ``batch_size``,
      ``num_epochs`` and ``validation_ratio``.
    - The S3 location of this worker's data: ``bucket_name`` and a
      URL-quoted ``key``. The key must begin with
      ``<worker_index>_<num_workers>_``.

    The data is shuffled with a fixed seed and split into train and
    validation sets. For every full training batch the worker pulls the
    shared model, runs one SVM step and pushes the weight update. Validation
    accuracy is printed after each epoch. The Thrift connection is always
    closed.
    """
    startTs = time.time()
    num_features = event['num_features']
    learning_rate = event["learning_rate"]
    batch_size = event["batch_size"]
    num_epochs = event["num_epochs"]
    validation_ratio = event["validation_ratio"]

    # Reading data from S3
    bucket_name = event['bucket_name']
    key = urllib.parse.unquote_plus(event['key'], encoding='utf-8')
    print(f"Reading training data from bucket = {bucket_name}, key = {key}")
    key_splits = key.split("_")
    worker_index = int(key_splits[0])
    num_worker = int(key_splits[1])

    # read file from s3
    file = get_object(bucket_name, key).read().decode('utf-8').split("\n")
    print("read data cost {} s".format(time.time() - startTs))

    parse_start = time.time()
    dataset = SparseDatasetWithLines(file, num_features)
    print("parse data cost {} s".format(time.time() - parse_start))

    preprocess_start = time.time()
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_ratio * dataset_size))
    # Fixed seed so every run shuffles identically.
    np.random.seed(42)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    train_set = [dataset[i] for i in train_indices]
    val_set = [dataset[i] for i in val_indices]

    print("preprocess data cost {} s".format(time.time() - preprocess_start))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(constants.HOST, constants.PORT)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()
    try:
        # test thrift connection
        ps_client.ping(t_client)
        print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
            constants.HOST, constants.PORT))

        svm = SparseSVM(train_set, val_set, num_features, num_epochs,
                        learning_rate, batch_size)

        # register model
        model_name = "w.b"
        model_length = num_features
        ps_client.register_model(t_client, worker_index, model_name,
                                 model_length, num_worker)
        ps_client.exist_model(t_client, model_name)
        print("register and check model >>> name = {}, length = {}".format(
            model_name, model_length))

        # Training the Model
        train_start = time.time()
        iter_counter = 0
        # Only full batches are trained; a trailing partial batch is skipped.
        num_batches = math.floor(len(train_set) / batch_size)

        for epoch in range(num_epochs):
            epoch_start = time.time()
            print(f"worker {worker_index} epoch {epoch}")

            for batch_idx in range(num_batches):
                batch_start = time.time()
                # pull latest model
                ps_client.can_pull(t_client, model_name, iter_counter,
                                   worker_index)
                latest_model = ps_client.pull_model(t_client, model_name,
                                                    iter_counter, worker_index)
                svm.weights = torch.from_numpy(latest_model).reshape(
                    num_features, 1)

                svm.next_batch(batch_idx)
                acc = svm.one_epoch(batch_idx, epoch)
                compute_end = time.time()

                sync_start = time.time()
                # NOTE(review): svm.weights was reshaped to (num_features, 1)
                # above; if latest_model is 1-D this subtraction may
                # broadcast to (num_features, num_features) — confirm shapes.
                w_update = svm.weights - latest_model
                ps_client.can_push(t_client, model_name, iter_counter,
                                   worker_index)
                ps_client.push_update(t_client, model_name, w_update,
                                      learning_rate, iter_counter,
                                      worker_index)
                ps_client.can_pull(t_client, model_name, iter_counter + 1,
                                   worker_index)  # sync all workers
                sync_time = time.time() - sync_start

                print(
                    'Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, train acc: %.4f, epoch cost %.4f, '
                    'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                    % (epoch + 1, num_epochs, batch_idx + 1, num_batches,
                       time.time() - train_start, acc,
                       time.time() - epoch_start, time.time() - batch_start,
                       compute_end - batch_start, sync_time))
                iter_counter += 1

            val_acc = svm.evaluate()
            print("Epoch takes {}s, validation accuracy: {}".format(
                time.time() - epoch_start, val_acc))
    finally:
        transport.close()
#!/usr/bin/env python

import sys

sys.path.append("gen-py")

from thrift.transport import TTransport
from thrift.transport import TSocket
from thrift.protocol import TBinaryProtocol

from risk import Rating
from risk.ttypes import *

# Query the risk Rating service on localhost:9090 and print both scores.
transport = TTransport.TBufferedTransport(TSocket.TSocket("localhost", 9090))
protocol = TBinaryProtocol.TBinaryProtocol(transport)

client = Rating.Client(protocol)
transport.open()
try:
    # print() with one argument behaves the same on Python 2 and 3.
    print(client.piotroski("JAVA"))
    print(client.altman_z("GOOG"))
finally:
    transport.close()
Example #13
0
def _load_flat_params(model, flat, shapes, lengths):
    """Overwrite ``model``'s parameters in place from one flat vector.

    ``flat`` is consumed in ``model.parameters()`` order; parameter ``i``
    takes the next ``lengths[i]`` values, reshaped to ``shapes[i]`` as float32.
    """
    pos = 0
    for param, shape, length in zip(model.parameters(), shapes, lengths):
        chunk = np.asarray(flat[pos:pos + length], dtype=np.float32)
        param.data = torch.from_numpy(chunk.reshape(shape))
        pos += length


def _flatten_grads(model):
    """Concatenate every parameter's gradient into one flat float64 vector.

    A parameter without a gradient contributes zeros, so the vector always has
    the registered model length.
    """
    chunks = []
    for param in model.parameters():
        if param.grad is None:
            chunks.append(np.zeros(param.numel(), dtype=np.float64))
        else:
            chunks.append(param.grad.detach().cpu().numpy().ravel())
    if not chunks:
        return np.zeros(0, dtype=np.float64)
    # One concatenate instead of growing an array per layer (was quadratic).
    return np.concatenate(chunks).astype(np.float64)


def handler(event, context):
    """Train MobileNet on one data shard as a parameter-server worker.

    Reads ``data_bucket``, ``rank``, ``num_workers`` and ``key`` from
    ``event`` and connects to the Thrift parameter server at
    ``constants.HOST:constants.PORT``. It then downloads the training shard
    ``key`` and the test set ``test_file`` from the bucket and registers the
    flattened model with the server. For every batch it pulls the latest
    model, computes gradients locally, pushes them, and waits until all
    workers have pushed.

    ``learning_rate``, ``num_epochs`` and ``test_file`` are not defined in this
    function; presumably they are module-level settings (confirm).

    Args:
        event: invocation payload with the keys listed above.
        context: runtime context; unused.
    """
    bucket = event['data_bucket']
    worker_index = event['rank']
    num_worker = event['num_workers']
    key = event['key']

    print('bucket = {}'.format(bucket))
    print('number of workers = {}'.format(num_worker))
    print('worker index = {}'.format(worker_index))

    # Thrift connection to the parameter server. Buffering is critical
    # because raw sockets are very slow.
    transport = TTransport.TBufferedTransport(
        TSocket.TSocket(constants.HOST, constants.PORT))
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    t_client = ParameterServer.Client(protocol)
    transport.open()

    try:
        # Test the thrift connection.
        ps_client.ping(t_client)
        print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
            constants.HOST, constants.PORT))

        print('data_bucket = {}\n worker_index:{}\n num_worker:{}\n key:{}'.
              format(bucket, worker_index, num_worker, key))

        # Read the training shard and the test set from S3.
        readS3_start = time.time()
        trainset = torch.load(download_file(bucket, key))
        testset = torch.load(download_file(bucket, test_file))
        print("read data cost {} s".format(time.time() - readS3_start))

        preprocess_start = time.time()
        batch_size = 200
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=batch_size,
                                                   shuffle=True)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=batch_size,
                                                  shuffle=False)
        device = 'cpu'
        print("preprocess data cost {} s".format(time.time() -
                                                 preprocess_start))

        model = MobileNet().to(device)

        # CrossEntropyLoss applies softmax internally.
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=learning_rate,
                                    momentum=0.9,
                                    weight_decay=5e-4)

        # Register the flattened model (all parameters, in order) with the PS.
        model_name = "mobilenet"
        parameter_shape = [tuple(param.shape) for param in model.parameters()]
        parameter_length = [int(np.prod(shape)) for shape in parameter_shape]
        model_length = sum(parameter_length)

        ps_client.register_model(t_client, worker_index, model_name,
                                 model_length, num_worker)
        ps_client.exist_model(t_client, model_name)
        print("register and check model >>> name = {}, length = {}".format(
            model_name, model_length))

        # Training the model.
        num_batches = len(train_loader)
        train_start = time.time()
        iter_counter = 0
        for epoch in range(num_epochs):
            epoch_start = time.time()
            for batch_index, (inputs, targets) in enumerate(train_loader):
                print("------worker {} epoch {} batch {}------".format(
                    worker_index, epoch, batch_index))
                batch_start = time.time()
                # test() below runs after every batch and may switch the
                # model to eval mode (confirm); re-enter training mode each
                # batch so dropout/batch-norm behave correctly.
                model.train()

                # Pull the latest global model into the local replica.
                ps_client.can_pull(t_client, model_name, iter_counter,
                                   worker_index)
                latest_model = ps_client.pull_model(t_client, model_name,
                                                    iter_counter, worker_index)
                _load_flat_params(model, latest_model, parameter_shape,
                                  parameter_length)

                # Forward + backward.
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                optimizer.zero_grad()
                loss.backward()

                # Flatten and concatenate the gradients of weights and biases
                # (the original pushed param.data, i.e. the weights).
                param_grad = _flatten_grads(model)
                print("model_length = {}".format(param_grad.shape))
                cal_time = time.time() - batch_start

                # Push the gradient to the PS, then wait for all workers.
                sync_start = time.time()
                print(
                    ps_client.can_push(t_client, model_name, iter_counter,
                                       worker_index))
                print(
                    ps_client.push_grad(t_client, model_name, param_grad,
                                        learning_rate, iter_counter,
                                        worker_index))
                print(
                    ps_client.can_pull(t_client, model_name, iter_counter + 1,
                                       worker_index))  # sync all workers
                sync_time = time.time() - sync_start

                print(
                    'Epoch: [%d/%d], Step: [%d/%d] >>> Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                    'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                    % (epoch + 1, num_epochs, batch_index + 1, num_batches,
                       time.time() - train_start, loss.item(),
                       time.time() - epoch_start, time.time() - batch_start,
                       cal_time, sync_time))
                iter_counter += 1

                # NOTE(review): evaluates the full test set after every batch;
                # possibly meant to run once per epoch.
                test(epoch, model, test_loader, criterion, device)
                # NOTE(review): the local step is overwritten by the next pull
                # from the PS, so it has no lasting effect on training.
                optimizer.step()
    finally:
        transport.close()
Exemple #14
0
def vector_addition_dfe(size, scalar, in_a, in_b):
    """VectorAddition DFE implementation.

    Drives the VectorAddition maxfile through the Thrift server on
    localhost:9090. It writes ``in_a`` to LMem, streams ``in_b`` through the
    default action with ``scalar`` as parameter ``A``, and reads back ``size``
    int32 results. It then frees the server-side allocations and prints the
    duration of every step.

    Args:
        size: number of int32 elements in each stream.
        scalar: value passed to the kernel's ``A`` parameter.
        in_a: input vector written to LMem.
        in_b: input vector streamed to input ``y``.

    Returns:
        The output stream read back with ``receive_data_int32_t``.

    On ``Thrift.TException`` the message is printed, the connection is closed
    and the process exits with status -1.
    """
    transport = None
    try:
        start_time = time.time()
        # Make socket
        socket = TSocket.TSocket('localhost', 9090)

        # Buffering is critical. Raw sockets are very slow
        transport = TTransport.TBufferedTransport(socket)

        # Wrap in a protocol
        protocol = TBinaryProtocol.TBinaryProtocol(transport)

        # Create a client to use the protocol encoder
        client = VectorAdditionService.Client(protocol)
        print('Creating a client:\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Connect!
        start_time = time.time()
        transport.open()
        print('Opening connection:\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Initialize maxfile
        start_time = time.time()
        max_file = client.VectorAddition_init()
        print('Initializing maxfile:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Load DFE
        start_time = time.time()
        max_engine = client.max_load(max_file, '*')
        print('Loading DFE:\t\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Allocate and send input streams to server
        start_time = time.time()
        address_in_a = client.malloc_int32_t(size)
        client.send_data_int32_t(address_in_a, in_a)

        address_in_b = client.malloc_int32_t(size)
        client.send_data_int32_t(address_in_b, in_b)
        print('Sending input data:\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Allocate memory for output stream on server
        start_time = time.time()
        address_data_out = client.malloc_int32_t(size)
        print('Allocating memory for output stream on server:\t%.5lfs' %
              (time.time() - start_time))

        # Write vector a to LMem (4 bytes per int32 element)
        start_time = time.time()
        # NOTE(review): this action handle is replaced below without being
        # freed; confirm whether it should also be passed to client.free.
        action = client.max_actions_init(max_file, "writeLMem")
        client.max_set_param_uint64t(action, "address", 0)
        client.max_set_param_uint64t(action, "nbytes", size * 4)
        client.max_queue_input(action, "cpu_to_lmem", address_in_a, size * 4)
        client.max_run(max_engine, action)
        print('Writing to LMem:\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Add two vectors and a scalar
        start_time = time.time()
        action = client.max_actions_init(max_file, "default")
        client.max_set_param_uint64t(action, "N", size)
        client.max_set_param_uint64t(action, "A", scalar)
        client.max_queue_input(action, "y", address_in_b, size * 4)
        client.max_queue_output(action, "s", address_data_out, size * 4)
        client.max_run(max_engine, action)
        print('Vector addition time:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Unload DFE
        start_time = time.time()
        client.max_unload(max_engine)
        print('Unloading DFE:\t\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Get output stream from server
        start_time = time.time()
        data_out = client.receive_data_int32_t(address_data_out, size)
        print('Getting output stream:\t(size = %d bit)\t%.5lfs' %
              ((size * 32), (time.time() - start_time)))

        # Free allocated memory for streams on server
        start_time = time.time()
        client.free(address_in_a)
        client.free(address_in_b)
        client.free(address_data_out)
        client.free(action)
        print('Freeing allocated memory for streams on server:\t%.5lfs' %
              (time.time() - start_time))

        # Free allocated maxfile data
        start_time = time.time()
        client.VectorAddition_free()
        print('Freeing allocated maxfile data:\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Close!
        start_time = time.time()
        transport.close()
        print('Closing connection:\t\t\t\t%.5lfs' % (time.time() - start_time))

    except Thrift.TException as thrift_exception:
        print('%s' % (thrift_exception.message))
        if transport is not None:
            transport.close()
        sys.exit(-1)

    # The original computed data_out but never handed it to the caller.
    return data_out
Exemple #15
0
 def __init__(self, hostname, port):
     """Connect to an HBase Thrift server at ``hostname:port``.

     Builds a buffered transport over a plain socket, wraps it in the binary
     protocol, creates the ``Hbase.Client`` and opens the connection. The
     transport, protocol and client are kept on the instance.
     """
     raw_socket = TSocket.TSocket(hostname, port)
     self.transport = TTransport.TBufferedTransport(raw_socket)
     self.protocol = TBinaryProtocol.TBinaryProtocol(self.transport)
     self.client = Hbase.Client(self.protocol)
     # The connection is opened last, after the client is ready.
     self.transport.open()
Exemple #16
0
def _json_default(value):
    """``json.dumps`` fallback that turns NumPy scalars/arrays into builtins."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError("Object of type %s is not JSON serializable" %
                    type(value).__name__)


def execute_hurdle(make_plot=False,
                   result_file="results.json",
                   host='127.0.0.1',
                   port=9090,
                   seed=None,
                   initial_state=None,
                   num_trials=10,
                   num_rounds=30000,
                   scoring_rounds=1000,
                   test_label="team"):
    """Run the Hurdle 3 trials against a Hurdle3Execution Thrift server.

    Each trial builds a fresh 10-state PSM, whose seed is drawn from a
    generator seeded with ``seed``. It then runs ``run_trial`` against the
    remote client. The hurdle passes when at least 6 trials pass. Results are
    written as JSON to ``result_file``, then ``client.stop()`` is called.

    Args:
        make_plot: accepted for interface compatibility; unused here.
        result_file: path the JSON results are written to.
        host: Hurdle3Execution server host.
        port: Hurdle3Execution server port.
        seed: top-level seed for per-trial seeds and initial states; the same
            seed reproduces the same trial seeds.
        initial_state: fixed initial PSM state for every trial, or ``None`` to
            draw one per trial.
        num_trials: number of trials to run.
        num_rounds: rounds per trial, passed to ``run_trial``.
        scoring_rounds: passed to ``run_trial``.
        test_label: stored in the results under ``"test_label"``.
    """
    num_states = 10

    avg_score_threshold = 2.0
    trial_pass_threshold = 6

    expected = expected_random_score(num_states)
    print("expected score of random guesser is {}".format(expected *
                                                          num_rounds))
    print("score required to pass a trial is {}".format(avg_score_threshold *
                                                        num_rounds))
    print("Number of trials passed to pass Hurdle 3 is  {} out of {}".format(
        trial_pass_threshold, num_trials))

    # Buffering is critical. Raw sockets are very slow.
    transport = TTransport.TBufferedTransport(TSocket.TSocket(host, port))
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    client = Hurdle3Execution.Client(protocol)
    transport.open()

    try:
        results = {"trials": {}, "test_label": test_label}

        # A dedicated generator makes the set of trial seeds repeatable for a
        # given top-level seed without forcing every trial to share one seed.
        rng = np.random.RandomState(seed)

        for t in range(num_trials):
            # Generate a unique seed per trial.
            trial_seed = rng.randint(0, 0xffffffff)

            if initial_state is None:
                # rng.choice returns numpy.int64, which json.dumps rejects.
                trial_initial_state = int(rng.choice(range(num_states)))
            else:
                trial_initial_state = initial_state

            # A new probabilistic state machine for every trial, possibly with
            # a new initial state and seed.
            psm = PSM(num_states, trial_initial_state, trial_seed)

            print("starting trial {} of {} with trial seed {}".format(
                t, num_trials, trial_seed))

            trial_results = run_trial(t, num_rounds, scoring_rounds,
                                      avg_score_threshold, client, psm)
            trial_results["seed"] = trial_seed
            trial_results["initial_state"] = trial_initial_state
            results["trials"][t] = trial_results

        # Count the trials that passed (int() guards against NumPy integers).
        trials_passed = int(
            sum(results["trials"][i]["trial_pass"] for i in range(num_trials)))

        print(
            "Number of trials passed: {} Number of trials passed required to pass Hurdle 3: {}"
            .format(trials_passed, trial_pass_threshold))
        hurdle_pass = bool(trials_passed >= trial_pass_threshold)

        print("Hurdle 3 Passed? {}".format(hurdle_pass))

        results["main_seed"] = seed
        results["num_trials"] = num_trials
        results["num_states"] = num_states
        results["trials_passed"] = trials_passed
        results["trial_pass_threshold"] = trial_pass_threshold
        results["hurdle_pass"] = hurdle_pass

        with open(result_file, 'w') as f:
            # run_trial may return NumPy values; convert them instead of
            # failing at the very end of a long run.
            f.write(json.dumps(results, default=_json_default))

        print("Writing results to file: {}".format(result_file))

        client.stop()
    finally:
        transport.close()
Exemple #17
0
 def __init__(self):
     """Connect to the ClassificationService Thrift server.

     The address comes from ``config.IP`` and ``config.PORT``. A buffered
     transport with the binary protocol is used. The transport and client are
     kept on the instance, and the connection is opened immediately.
     """
     buffered = TTransport.TBufferedTransport(
         TSocket.TSocket(config.IP, config.PORT))
     self.transport = buffered
     self.client = ClassificationService.Client(
         TBinaryProtocol.TBinaryProtocol(buffered))
     self.transport.open()
Exemple #18
0
 def __init__(self, host, port):
     """Open a buffered binary-protocol connection to a GraphService server.

     The transport is opened before the ``GraphService.Client`` is created.
     The transport, protocol and client are stored on the instance.
     """
     raw_socket = TSocket.TSocket(host, port)
     buffered = TTransport.TBufferedTransport(raw_socket)
     self.transport = buffered
     self.protocol = TBinaryProtocol.TBinaryProtocol(buffered)
     self.transport.open()
     self.client = GraphService.Client(self.protocol)
Exemple #19
0
def _is_blank(value):
    """Return True when a CLI parameter is missing (None) or only whitespace."""
    return value is None or not str(value).strip()


def main(argv):
    """Execute one clish command against the ``voc`` MonetDB database and the
    ``aev_config`` Thrift service on localhost:34532.

    ``argv`` is parsed with getopt: ``-m <mode>``, ``-c <command>``, ``-n`` and
    ``--i1`` .. ``--i10 <value>``. The meaning of the ``--iN`` values depends on
    (mode, command). Handled modes: ``enable``, ``config``, ``config-iface``,
    ``config_collector_sess`` and ``config-tunnel``.

    The helpers called here share state through the module globals declared
    below (``cursor``, ``ifprop``, ``vxtunprop``, ``ret``, ...). The database
    connection and the Thrift transport are always closed on return.

    Exits with status 2 when the options cannot be parsed.
    """
    global cursor
    global ifprop
    global vxtunprop
    global ret
    global vxret
    global collector_list

    collSessRet = 100
    collSessId = ''

    # Parse arguments before opening any connection so that a usage error
    # leaves nothing open.
    long_options = ["i%d=" % index for index in range(1, 11)]
    try:
        opts, args = getopt.getopt(argv, "m:c:n", long_options)
    except getopt.GetoptError:
        print('test.py -m <mode> -c <command> [-n] [--i1 <value> ... --i10 <value>]')
        sys.exit(2)

    # All params passed by clish stay local to main(); --iN -> params[N - 1].
    # Exact matching is used: the original `opt in ("--i1")` tested substring
    # membership in a string, not membership in a tuple.
    param_slot = {"--i%d" % index: index - 1 for index in range(1, 11)}
    params = [None] * 10
    mode = command = None
    no = 0
    for opt, arg in opts:
        if opt in param_slot:
            params[param_slot[opt]] = arg
            if opt == "--i9":
                print("\n" + arg)
        elif opt == "-m":
            mode = arg
        elif opt == "-c":
            command = arg
        elif opt == "-n":
            no = arg
    param1, param2 = params[0], params[1]

    connection = pymonetdb.connect(username="******",
                                   password="******",
                                   hostname="localhost",
                                   database="voc")
    transport = None
    try:
        connection.set_autocommit(True)
        cursor = connection.cursor()
        cursor.arraysize = 100

        # Buffering is critical. Raw sockets are very slow.
        transport = TTransport.TBufferedTransport(
            TSocket.TSocket('localhost', 34532))
        client = aev_config.Client(TBinaryProtocol.TBinaryProtocol(transport))
        transport.open()
        client.ping()

        if mode == 'enable':
            if command == 'show_iface':
                cursor.execute("SELECT * FROM if_table")
                print(cursor.fetchmany())

        elif mode == 'config':
            if command == "collsession":
                # The session id is checked as an int by the CLI only.
                collSessRet = collector_session_check(param1)

            elif command == "delTunnel":
                vxtunprop = aev_vxlan_tunnel_prop()
                # Delete the rows in both tables if the tunnel name exists.
                deleteRet = deleteBothTnlTbl(param1)
                if deleteRet == True:
                    # Send the delete to the BCM server.
                    client.aev_if_prop_delete(1, vxtunprop)
                elif deleteRet == False:
                    print(
                        '\n Such a tunnel name do not exists, please check the input'
                    )

            elif command == "delcollsession":
                vxtunprop = aev_vxlan_tunnel_prop()
                delSessId = param1
                # A missing --i1 used to pass as the literal string 'None'.
                if not _is_blank(delSessId):
                    print('\n The session id is not empty')
                    # 1. Delete the session id row.
                    # 2. Delete the session id from the tunnel table.
                    # 3. Send the delete to the board.
                    print('\n The result of the delete from collector table is \n')
                    print(delSessTabRow(delSessId))
                    print(
                        '\n The result of update the session id from the tunnel table is \n'
                    )
                    tunnelName = deleteSessIdTnlTbl(delSessId)
                    if tunnelName != 0:
                        print('\n The result of delete sent to the BCM board \n')
                        vxtunprop.tunnel_name = tunnelName
                        print('\n Send delete to the BCM \n')
                        example = client.aev_if_prop_delete(1, vxtunprop)
                        print(example)
                    else:
                        print('\n the tunnelname was not found')

            elif command == "iface":
                ret = set_config_mode(param1)

            elif command == "tunnel":
                tunName = param2
                tunType = param1
                if _is_blank(tunName):
                    # Previously the invalid name was still written after
                    # this message.
                    print('\n The tunnel name is not as per the sepcification')
                else:
                    flag = crupdateTnlTab(tunName, tunType)
                    if flag == 1:
                        print('\n New Tunnel row has been created and type inserted ')
                    elif flag == 2:
                        print('\n Tunnel row has been modified with new type')
                    elif flag == 3:
                        print('\n Tunnel row has not been modified and type exists')

        elif mode == "config-iface":
            ifprop = aev_if_prop()
            iface_setters = {
                'speed': lambda: set_speed(param1, param2),
                'shutdown': lambda: set_shutdown(param1),
                'noshutdown': lambda: set_noshutdown(param1),
                'mtu': lambda: set_mtu(param1, param2),
                'autonego': lambda: set_autonego(param1, param2),
            }
            setter = iface_setters.get(command)
            if setter is not None:
                ret = setter()
                # The update used to be sent only for 'autonego' (indentation
                # bug); every successful interface setting is pushed now.
                if ret == 0:
                    example = client.aev_if_prop_update(1, ifprop)
                    print(example)

        elif mode == "config_collector_sess":
            if command == "dest_tunnel":
                collSessId = param2  # int
                tunName = param1  # string
                srcVlan = 0  # int
                destTunnelRet = dstTunnelName(tunName, srcVlan, collSessId)

                if destTunnelRet == 0:
                    print("Tunnel port_access is updated in VXLAN tunnel table")
                elif destTunnelRet == 1:
                    print("Tunnel doesn't exist but the VLAN is updated")
                elif destTunnelRet == 2:
                    print(
                        "\n The tunnel cannot be set as the source interface has not been found in the session table row"
                    )
                elif destTunnelRet == -1:
                    print("\n Tunnel exists but currently in use by other session")

            elif command == "sourceint":
                vxtunprop = aev_vxlan_tunnel_prop()
                collSessId = param2  # integer
                print(type(collSessId))
                sourceInt = param1  # subcommand string

                # Check whether the source interface is valid.
                sourceIntRet = checkValidInterface(sourceInt)

                if sourceIntRet == True:
                    # Valid and free: try to update the collector table.
                    intSessRel = modifyNewCollectorTable(sourceInt, collSessId)

                    if intSessRel == 0 or intSessRel == 1:
                        print(
                            "\n Already the interface exists with this session or the interface is assigned"
                        )
                        # If a destination tunnel is set, check or update its
                        # ifindex.
                        dstTunnelTestRet = dstTunnelTest(sourceInt, collSessId)

                        if dstTunnelTestRet == True:
                            # True only when the helper has filled vxtunprop.
                            example = client.aev_vxlan_tunnel_create(1, vxtunprop)
                            print(example)
                        else:
                            print(
                                '\nEither the interface is pushed inside the vxlan_tnl_tab or the same interface was sent by the user'
                            )
                    elif intSessRel == -1:
                        print("\n Can't help, somebody else is using it")

                elif sourceIntRet == False:
                    print("Not a valid interface selected")

        elif mode == "config-tunnel":
            vxtunprop = aev_vxlan_tunnel_prop()

            if command == "tunnelprop":
                # updateVxlanTab yields one status per processed entry:
                #  2 -> something is wrong with the provided parameters
                #  1 -> everything is set; send create to the BCM
                # -1 -> send delete (tunnel name only) to the BCM
                #  0 -> source interface not set yet; other entries updated
                for vxretobj in updateVxlanTab(list(params)):
                    print(vxretobj)
                    if vxretobj == 0:
                        print(
                            '\n The values except the source interface has been updated into vx_tnl_tbl'
                        )

                    if vxretobj == -1:
                        print('\n Send delete to the BCM')
                        ret = client.aev_if_prop_delete(1, vxtunprop)
                        if ret == True:
                            # Parameterized: the tunnel name is not pasted
                            # into the SQL text.
                            cursor.execute(
                                "DELETE FROM vxlan_tnl_table WHERE vxlan_tnl_table.tnl_name=%s",
                                (vxtunprop.tunnel_name,))
                    if vxretobj == 1:
                        example = client.aev_if_prop_create(1, vxtunprop)
                        print(example)

                    if vxretobj == 2:
                        print("\n Something is wrong with provided parameters")
    finally:
        connection.close()
        if transport is not None:
            transport.close()
def simple_dfe(size, data_in):
    """Simple DFE implementation.

    Drives the Simple maxfile through the Thrift server on localhost:9090. It
    sends ``data_in`` (``size`` floats), runs the Simple action, and reads
    back ``size`` floats. It then frees the server-side allocations and
    prints the duration of every step.

    Args:
        size: number of float elements in the input and output streams.
        data_in: input stream sent with ``send_data_float``.

    Returns:
        The output stream read back with ``receive_data_float``.

    On ``Thrift.TException`` the message is printed, the connection is closed
    and the process exits with status -1.
    """
    transport = None
    try:
        start_time = time.time()

        # Make socket
        socket = TSocket.TSocket('localhost', 9090)

        # Buffering is critical. Raw sockets are very slow
        transport = TTransport.TBufferedTransport(socket)

        # Wrap in a protocol
        protocol = TBinaryProtocol.TBinaryProtocol(transport)

        # Create a client to use the protocol encoder
        client = SimpleService.Client(protocol)
        print('Creating a client:\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Connect!
        start_time = time.time()
        transport.open()
        print('Opening connection:\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Initialize maxfile
        start_time = time.time()
        max_file = client.Simple_init()
        print('Initializing maxfile:\t\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Load DFE
        start_time = time.time()
        max_engine = client.max_load(max_file, '*')
        print('Loading DFE:\t\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Allocate and send input streams to server
        start_time = time.time()
        address_data_in = client.malloc_float(size)
        client.send_data_float(address_data_in, data_in)
        print('Sending input data:\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Allocate memory for output stream on server
        start_time = time.time()
        address_data_out = client.malloc_float(size)
        print('Allocating memory for output stream on server:\t%.5lfs' %
              (time.time() - start_time))

        # Action default
        start_time = time.time()
        actions = Simple_actions_t_struct(size, address_data_in,
                                          address_data_out)
        address_actions = client.send_Simple_actions_t(actions)
        client.Simple_run(max_engine, address_actions)
        print('Simple time:\t\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Unload DFE
        start_time = time.time()
        client.max_unload(max_engine)
        print('Unloading DFE:\t\t\t\t\t%.5lfs' % (time.time() - start_time))

        # Get output stream from server
        start_time = time.time()
        data_out = client.receive_data_float(address_data_out, size)
        print('Getting output stream:\t(size = %d bit)\t%.5lfs' %
              ((size * 32), (time.time() - start_time)))

        # Free allocated memory for streams on server
        start_time = time.time()
        client.free(address_data_in)
        client.free(address_data_out)
        client.free(address_actions)
        print('Freeing allocated memory for streams on server:\t%.5lfs' %
              (time.time() - start_time))

        # Free allocated maxfile data
        start_time = time.time()
        client.Simple_free()
        print('Freeing allocated maxfile data:\t\t\t%.5lfs' %
              (time.time() - start_time))

        # Close!
        start_time = time.time()
        transport.close()
        print('Closing connection:\t\t\t\t%.5lfs' % (time.time() - start_time))

    except Thrift.TException as thrift_exception:
        print('%s' % (thrift_exception.message))
        if transport is not None:
            transport.close()
        sys.exit(-1)

    # The original computed data_out but never handed it to the caller.
    return data_out
from thrift.protocol import TBinaryProtocol
from thrift.transport import TSocket, TTransport

from thriftDemo.service.hello import HelloService

if __name__ == '__main__':
    # Ping the HelloService on localhost:12580 and print the reply. The
    # transport is closed even if the call fails.
    transport = TTransport.TBufferedTransport(
        TSocket.TSocket("localhost", 12580))
    client = HelloService.Client(TBinaryProtocol.TBinaryProtocol(transport))
    transport.open()
    try:
        pong = client.ping("hello, world!")
        print("pong %s" % pong)
    finally:
        transport.close()
Exemple #22
0
def handler(event, context):
    """Run one round of parameter-server training (Lambda-style handler).

    Trains ``event['model']`` for up to ``run_epochs`` epochs starting at
    ``start_epoch``.

    For every batch, the worker:
      * pulls the latest flat model from the Thrift parameter server (PS),
      * computes gradients locally,
      * pushes them scaled by ``-lr / n_workers``,
      * waits until every worker has pushed.

    If training is not finished, worker 0 checkpoints the model to
    ``cp_bucket``. Every worker then asynchronously re-invokes
    ``function_name`` with ``start_epoch`` advanced.

    Args:
        event: dict holding the dataset / PS / training / hyper-parameter
            settings read below (``train_file``, ``host``, ``lr``, ...).
        context: invocation context; not used.
    """
    start_time = time.time()

    # dataset setting
    train_file = event['train_file']
    test_file = event['test_file']
    data_bucket = event['data_bucket']
    n_features = event['n_features']
    n_classes = event['n_classes']
    n_workers = event['n_workers']
    worker_index = event['worker_index']
    cp_bucket = event['cp_bucket']

    # ps setting
    host = event['host']
    port = event['port']

    # training setting
    model_name = event['model']
    optim = event['optim']
    sync_mode = event['sync_mode']
    assert model_name.lower() in MLModel.Deep_Models
    assert optim.lower() in Optimization.Grad_Avg
    assert sync_mode.lower() in Synchronization.Reduce

    # hyper-parameter
    learning_rate = event['lr']
    batch_size = event['batch_size']
    n_epochs = event['n_epochs']
    start_epoch = event['start_epoch']
    run_epochs = event['run_epochs']

    function_name = event['function_name']

    print('data bucket = {}'.format(data_bucket))
    print("train file = {}".format(train_file))
    print("test file = {}".format(test_file))
    print('number of workers = {}'.format(n_workers))
    print('worker index = {}'.format(worker_index))
    print('model = {}'.format(model_name))
    print('optimization = {}'.format(optim))
    print('sync mode = {}'.format(sync_mode))
    print('start epoch = {}'.format(start_epoch))
    print('run epochs = {}'.format(run_epochs))
    print('host = {}'.format(host))
    print('port = {}'.format(port))

    print("Run function {}, round: {}/{}, epoch: {}/{} to {}/{}".format(
        function_name,
        int(start_epoch / run_epochs) + 1, math.ceil(n_epochs / run_epochs),
        start_epoch + 1, n_epochs, start_epoch + run_epochs, n_epochs))

    # download file from s3
    storage = S3Storage()
    local_dir = "/tmp"
    read_start = time.time()
    storage.download(data_bucket, train_file,
                     os.path.join(local_dir, train_file))
    storage.download(data_bucket, test_file,
                     os.path.join(local_dir, test_file))
    print("download file from s3 cost {} s".format(time.time() - read_start))

    train_set = torch.load(os.path.join(local_dir, train_file))
    test_set = torch.load(os.path.join(local_dir, test_file))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True)
    n_train_batch = len(train_loader)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=100,
                                              shuffle=False)

    print("read data cost {} s".format(time.time() - read_start))

    random_seed = 100
    torch.manual_seed(random_seed)

    device = 'cpu'
    model = deep_models.get_models(model_name).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # load checkpoint model if it is not the first round
    if start_epoch != 0:
        checked_file = 'checkpoint_{}.pt'.format(start_epoch - 1)
        storage.download(cp_bucket, checked_file,
                         os.path.join(local_dir, checked_file))
        checkpoint_model = torch.load(os.path.join(local_dir, checked_file))

        model.load_state_dict(checkpoint_model['model_state_dict'])
        optimizer.load_state_dict(checkpoint_model['optimizer_state_dict'])
        print("load checkpoint model at epoch {}".format(start_epoch - 1))

    # Exclusive upper bound of the epochs trained in this round. Used after
    # the loop instead of the loop variable, which is unbound when the
    # range is empty.
    end_epoch = min(start_epoch + run_epochs, n_epochs)

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(host, port)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()
    try:
        # test thrift connection
        ps_client.ping(t_client)
        print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
            host, port))

        # register model: remember each parameter's shape and element count
        # so the flat vector pulled from the PS can be split back per layer.
        parameter_shape = []
        parameter_length = []
        for param in model.parameters():
            parameter_shape.append(param.data.numpy().shape)
            parameter_length.append(param.numel())
        model_length = sum(parameter_length)

        ps_client.register_model(t_client, worker_index, model_name,
                                 model_length, n_workers)
        ps_client.exist_model(t_client, model_name)
        print("register and check model >>> name = {}, length = {}".format(
            model_name, model_length))

        # Training the Model
        train_start = time.time()
        iter_counter = 0
        for epoch in range(start_epoch, end_epoch):

            model.train()
            epoch_start = time.time()

            train_acc = Accuracy()
            train_loss = Average()

            for batch_idx, (inputs, targets) in enumerate(train_loader):
                batch_start = time.time()
                batch_cal_time = 0
                batch_comm_time = 0

                # pull latest model and scatter it back into the layers
                ps_client.can_pull(t_client, model_name, iter_counter,
                                   worker_index)
                latest_model = ps_client.pull_model(t_client, model_name,
                                                    iter_counter,
                                                    worker_index)
                pos = 0
                for layer_index, param in enumerate(model.parameters()):
                    length = parameter_length[layer_index]
                    param.data = Variable(
                        torch.from_numpy(
                            np.asarray(latest_model[pos:pos + length],
                                       dtype=np.float32).reshape(
                                           parameter_shape[layer_index])))
                    pos += length
                batch_comm_time += time.time() - batch_start

                batch_cal_start = time.time()
                outputs = model(inputs)
                loss = F.cross_entropy(outputs, targets)
                optimizer.zero_grad()
                loss.backward()

                # Flatten and concat the GRADIENTS of weight and bias (the
                # previous code read param.data, i.e. pushed the weights).
                # Parameters without a gradient contribute zeros so the flat
                # vector keeps length model_length.
                flat_grads = []
                for param in model.parameters():
                    if param.grad is None:
                        flat_grads.append(np.zeros(param.numel()))
                    else:
                        flat_grads.append(param.grad.data.numpy().flatten())
                param_grad = np.concatenate(flat_grads).astype(np.float64)
                batch_cal_time += time.time() - batch_cal_start

                # push gradient to PS
                batch_push_start = time.time()
                ps_client.can_push(t_client, model_name, iter_counter,
                                   worker_index)
                ps_client.push_grad(t_client, model_name, param_grad,
                                    -1. * learning_rate / n_workers,
                                    iter_counter, worker_index)
                ps_client.can_pull(t_client, model_name, iter_counter + 1,
                                   worker_index)  # sync all workers
                batch_comm_time += time.time() - batch_push_start

                train_acc.update(outputs, targets)
                train_loss.update(loss.item(), inputs.size(0))

                optimizer.step()
                iter_counter += 1

                if batch_idx % 10 == 0:
                    print(
                        'Epoch: [%d/%d], Batch: [%d/%d], Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                        'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                        % (epoch + 1, n_epochs, batch_idx + 1, n_train_batch,
                           time.time() - train_start, loss.item(),
                           time.time() - epoch_start, time.time() - batch_start,
                           batch_cal_time, batch_comm_time))

            test_loss, test_acc = test(epoch, model, test_loader)

            print(
                'Epoch: {}/{},'.format(epoch + 1, n_epochs),
                'train loss: {},'.format(train_loss),
                'train acc: {},'.format(train_acc),
                'test loss: {},'.format(test_loss),
                'test acc: {}.'.format(test_acc),
            )
    finally:
        # Always release the PS connection, even if training fails.
        transport.close()

    # training is not finished yet, invoke next round
    if end_epoch < n_epochs:
        last_epoch = end_epoch - 1
        checkpoint_model = {
            'epoch': last_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss.average
        }

        checked_file = 'checkpoint_{}.pt'.format(last_epoch)

        if worker_index == 0:
            torch.save(checkpoint_model, os.path.join(local_dir, checked_file))
            storage.upload(cp_bucket, checked_file,
                           os.path.join(local_dir, checked_file))
            print("checkpoint model at epoch {} saved!".format(last_epoch))

        print(
            "Invoking the next round of functions. round: {}/{}, start epoch: {}, run epoch: {}"
            .format(
                int((last_epoch + 1) / run_epochs) + 1,
                math.ceil(n_epochs / run_epochs), last_epoch + 1, run_epochs))
        lambda_client = boto3.client('lambda')
        payload = {
            'train_file': event['train_file'],
            'test_file': event['test_file'],
            'data_bucket': event['data_bucket'],
            'n_features': event['n_features'],
            'n_classes': event['n_classes'],
            'n_workers': event['n_workers'],
            'worker_index': event['worker_index'],
            'cp_bucket': event['cp_bucket'],
            'host': event['host'],
            'port': event['port'],
            'model': event['model'],
            'optim': event['optim'],
            'sync_mode': event['sync_mode'],
            'lr': event['lr'],
            'batch_size': event['batch_size'],
            'n_epochs': event['n_epochs'],
            'start_epoch': last_epoch + 1,
            'run_epochs': event['run_epochs'],
            'function_name': event['function_name']
        }
        lambda_client.invoke(FunctionName=function_name,
                             InvocationType='Event',
                             Payload=json.dumps(payload))

    end_time = time.time()
    print("Elapsed time = {} s".format(end_time - start_time))
def findHeavyHitters(table, today=None, verbose=False):
    """
    Find heavy hitters in the given traffic (table) and store the results in
    the 'suspiciousheavyhitters' Hive table.

    An endpoint becomes a candidate when its packet or byte total for
    ``today`` exceeds mean + 3 * std over all endpoints that day. A candidate
    is written out when that total also exceeds its own mean + 3 * std over
    the previous 15 days. When the historical statistics make the comparison
    fail with TypeError (e.g. NULL std), it is written with confidence "MED".

    :param table: name of a ``netflow*`` or ``sflow*`` table (passed through
        ``scrub``); any other prefix terminates the process.
    :param today: date to analyse; defaults to the current date at call time.
    :param verbose: write progress messages to stdout.
    """
    if today is None:
        # Resolved per call: a default of datetime.date.today() in the
        # signature would be frozen at import time.
        today = datetime.date.today()

    histNbDay = 15
    date = "%d%02d%02d" % (today.year, today.month, today.day)
    dates = [
        "%d%02d%02d" % (x.year, x.month, x.day)
        for x in pd.date_range(today - datetime.timedelta(histNbDay),
                               today - datetime.timedelta(1))
    ]
    table = scrub(table)

    ## set some variables regarding the input data
    if table.startswith("netflow"):
        dataType = "netflow"
        endpointTypes = [("dstip", "da"), ("srcip", "sa")]
        req0 = "select {endpoint}, sum(ipkt) nbpkt, sum(ibyt) nbbyte from {table} where dt=%s group by {endpoint}"
        req1 = "select {genericLabel}, avg(nbpkt) as avgpkt, stddev_samp(nbpkt) as stdpkt, avg(nbbyt) as avgbyt, stddev_samp(nbbyt) as stdbyt from(select {endpointType} as {genericLabel}, dt, sum(ipkt) as nbpkt, sum(ibyt) as nbbyt from {table} where {endpointType} IN ({suspiciousIP}) and dt IN ({dates}) group by {endpointType}, dt order by {endpointType}, dt) group by {genericLabel}"
    elif table.startswith("sflow"):
        dataType = "sflow"
        endpointTypes = [("dstip", "dstip"), ("srcip", "srcip"),
                         ("dstip", "dstip6"), ("srcip", "srcip6")]
        req0 = "select {endpoint}, count(*) nbpkt, sum(ipsize) nbbyte from {table} where dt=%s and {endpoint}<>'' group by {endpoint}"
        req1 = "select {genericLabel}, avg(nbpkt) as avgpkt, stddev_samp(nbpkt) as stdpkt, avg(nbbyt) as avgbyt, stddev_samp(nbbyt) as stdbyt from(select {endpointType} as {genericLabel}, dt, count(*) as nbpkt, sum(ipsize) as nbbyt from {table} where {endpointType} IN ({suspiciousIP}) and dt IN ({dates}) group by {endpointType}, dt order by {endpointType}, dt) group by {genericLabel}"
    else:
        sys.stderr.write("Data type unknown!")
        sys.exit(-1)

    outputPath = "%s/suspiciousheavyhitters_%s_%s.txt" % (outputDirectory,
                                                          table, date)
    # `with` guarantees the file is closed even if a query fails.
    with open(outputPath, "w") as outputFile:
        cursor = presto.connect('localhost').cursor()
        for genericLabel, endpointType in endpointTypes:
            if verbose:
                sys.stdout.write("Looking for %s heavy hitters... (%s,%s)\n" %
                                 (date, table, genericLabel))
            suspiciousIP = set()
            # get today's data
            formatedReq = req0.format(endpoint=endpointType, table=table)
            cursor.execute(formatedReq, [date])
            res = cursor.fetchall()

            if not res:
                continue

            data = pd.DataFrame(res, columns=[genericLabel, "nbpkt", "nbbyt"])
            data.index = data.pop(genericLabel)

            # find today's heavy hitter (.loc replaces the removed .ix)
            for aggType in ["nbpkt", "nbbyt"]:
                threshold = data[aggType].mean() + 3 * data[aggType].std()
                suspiciousIP.update(
                    data.loc[data[aggType] > threshold].index.tolist())

            # check in past data if they had similar behavior
            if verbose: sys.stdout.write("Retrieve past data...\n")
            suspiciousIP = list(suspiciousIP)
            for i in range(0, len(suspiciousIP), 100):
                susIP = suspiciousIP[i:i + 100]
                # One bound "%s" placeholder per value, as req0 already does;
                # the driver quotes/escapes them. This replaces the Python-2
                # only str.translate(str(list), None, "u[]") hack.
                formatedReq1 = req1.format(
                    genericLabel=genericLabel,
                    endpointType=endpointType,
                    table=table,
                    suspiciousIP=", ".join(["%s"] * len(susIP)),
                    dates=", ".join(["%s"] * len(dates)))
                cursor.execute(formatedReq1, list(susIP) + dates)
                res = cursor.fetchall()

                if verbose: sys.stdout.write("Register suspicious IPs...\n")
                for ip, avgpkt, stdpkt, avgbyt, stdbyt in res:
                    currData = data.loc[ip]
                    if genericLabel == "dstip":
                        dstip = ip
                        srcip = ""
                    else:
                        dstip = ""
                        srcip = ip
                    try:
                        if currData["nbpkt"] > avgpkt + 3 * stdpkt or currData[
                                "nbbyt"] > avgbyt + 3 * stdbyt:
                            outputFile.write(
                                "%s\t%s\t%s\t%s\t%s\t\n" %
                                (srcip, dstip, currData["nbpkt"],
                                 currData["nbbyt"],
                                 confidence(currData["nbpkt"], avgpkt, stdpkt,
                                            currData["nbbyt"], avgbyt,
                                            stdbyt)))
                    except TypeError:
                        # Historical stats unusable (e.g. None): still report.
                        if verbose:
                            sys.stdout.write(
                                "!!Warning!! no past data for %s (avgpkt=%s, stdpkt=%s, avgbyt=%s, stdbyt=%s)\n"
                                % (ip, avgpkt, stdpkt, avgbyt, stdbyt))
                        outputFile.write("%s\t%s\t%s\t%s\t%s\t\n" %
                                         (srcip, dstip, currData["nbpkt"],
                                          currData["nbbyt"], "MED"))

    # Store results in Hive
    try:
        transport = TSocket.TSocket('localhost', 10000)
        transport = TTransport.TBufferedTransport(transport)
        protocol = TBinaryProtocol.TBinaryProtocol(transport)

        client = ThriftHive.Client(protocol)
        transport.open()
        try:
            client.execute(
                "create table if not exists suspiciousheavyhitters (srcip string, dstip string, pkt bigint, byte bigint, confidence string) partitioned by(dt string, dataSrc string) row format delimited fields terminated by '\t'"
            )
            client.execute(
                "load data local inpath '{dir}/suspiciousheavyhitters_{table}_{date}.txt' overwrite into table suspiciousheavyhitters partition (dt='{date}', dataSrc='{table}')"
                .format(table=table, date=date, dir=outputDirectory))
        finally:
            transport.close()

    except Thrift.TException as tx:
        sys.stderr.write('%s\n' % (tx.message))
Exemple #24
0
    def connect_thrift_node(self, node_to_connect):
        """ attempt to connect a node through thrift networking services

        Opens a buffered binary Thrift connection to ``node_to_connect``. It
        then calls ``register_node`` with this node and a fresh random pass
        phrase.

        On acceptance, the pass phrase, transport and client are stored on
        ``node_to_connect`` and it is marked connected. On rejection or on
        an exception, the attempt counter is bumped via
        ``net_dao.update_con_attempts`` and the transport is closed.

        Returns True only when the remote node accepted the registration.
        Returns False when it was rejected, an error occurred, or the node
        is already connected or already present in ``self.connections``.
        """
        connection_successful = False
        if not node_to_connect.connected:
            transport = None
            try:
                if node_to_connect not in self.connections:  # and \
                    # peer.connection_attempts < MAX_CONNECTION_ATTEMPTS and \
                    # len(self.peers) < self.max_outbound_connections:
                    logger().info('attempting connect_thrift_node %s:%s',
                                  node_to_connect.host, node_to_connect.port)

                    pass_phrase = str(uuid.uuid4())

                    # Make socket
                    transport = TSocket.TSocket(node_to_connect.host,
                                                int(node_to_connect.port))

                    # Buffering is critical. Raw sockets are very slow
                    transport = TTransport.TBufferedTransport(transport)

                    # Wrap in a protocol
                    protocol = TBinaryProtocol.TBinaryProtocol(transport)

                    # Create a client to use the protocol encoder
                    client = BlockchainService.Client(protocol)
                    # Connect
                    transport.open()
                    logger().info('about to register')
                    connection_successful = client.register_node(
                        self.this_node, pass_phrase)
                    logger().info('transport open to node %s',
                                  node_to_connect.node_id)
                    if connection_successful:
                        node_to_connect.connected = True
                        node_to_connect.pass_phrase = pass_phrase
                        node_to_connect.transport = transport
                        node_to_connect.client = client
                        logger().info(
                            '%s accepted outbound connection request.',
                            node_to_connect.node_id)
                        logger().info('node owner: %s', node_to_connect.owner)
                        logger().info('phases provided: %s',
                                      '{:05b}'.format(node_to_connect.phases))
                    else:
                        try:
                            net_dao.update_con_attempts(
                                node_to_connect
                            )  # incrementing connection attempts on fail
                        except Exception as ex:
                            template = "An exception of type {0} occured. Arguments:\n{1!r}"
                            message = template.format(
                                type(ex).__name__, ex.args)
                            logger().warning(message)
                        # Was `transport.close` (no call), which leaked the
                        # socket of every rejected connection.
                        transport.close()
                        print(node_to_connect.node_id +
                              ' rejected outbound connection request.')

            except Exception as ex:
                if not connection_successful:
                    # Release the socket before recording the failed attempt
                    # (closing an already-closed transport is harmless).
                    if transport is not None:
                        transport.close()
                    net_dao.update_con_attempts(node_to_connect)
                template = "An exception of type {0} occured. Arguments:\n{1!r}"
                message = template.format(type(ex).__name__, ex.args)
                logger().warning(message)
            finally:
                logger().info('connect_thrift_node %s',
                              str(connection_successful))
        return connection_successful
import configparser

import os
import os.path

import sys

# Read in configs from config file (its path is the first command-line argument)
conffile = sys.argv[1]
config = configparser.ConfigParser()
if not config.read(conffile):
    # ConfigParser.read() silently skips files it cannot open; stop with a
    # clear message instead of a KeyError on 'Host' below.
    sys.exit("Could not read config file: %s" % conffile)
host = config['DEFAULT']['Host']
port = config['DEFAULT']['Port']

# Connect to HBase Thrift server: This creates the socket transport and binary protocol and
# allows the Thrift client to connect and talk to the Thrift server.
transport = TTransport.TBufferedTransport(TSocket.TSocket(host, port))
protocol = TBinaryProtocol.TBinaryProtocolAccelerated(transport)

# Create and open the client connection: create the Client object you will be using to
# interact with HBase. From this client object, you will issue all your Gets and Puts.
client = Hbase.Client(protocol)
transport.open()

#########################################
# Define Table Name and Column Family Name Here
#########################################
tablename = 'haztable5'
cfname = 'cf1'

tableexists = 0
Exemple #26
0
    def do_POST(self):
        """
        Handles POST queries, which are usually Thrift messages.

        The body is decoded with the Thrift JSON protocol and routed by URL to
        one of these v6 processors:
          * Authentication,
          * Products,
          * CodeCheckerService (requires a product route).

        When authentication is enabled, unauthenticated requests to anything
        but '/Authentication' get HTTP 401.

        Any exception raised while routing or processing is turned into a
        Thrift TApplicationException (INTERNAL_ERROR) reply. That reply is
        sent with HTTP 200 so JSON Thrift clients can parse it.
        """

        client_host, client_port = self.client_address
        auth_session = self.__check_session_cookie()
        LOG.debug("%s:%s -- [%s] POST %s", client_host, str(client_port),
                  auth_session.user if auth_session else "Anonymous",
                  self.path)

        # Create new thrift handler.
        checker_md_docs = self.server.checker_md_docs
        checker_md_docs_map = self.server.checker_md_docs_map
        suppress_handler = self.server.suppress_handler
        version = self.server.version

        protocol_factory = TJSONProtocol.TJSONProtocolFactory()
        input_protocol_factory = protocol_factory
        output_protocol_factory = protocol_factory

        itrans = TTransport.TFileObjectTransport(self.rfile)
        itrans = TTransport.TBufferedTransport(
            itrans, int(self.headers['Content-Length']))
        otrans = TTransport.TMemoryBuffer()

        iprot = input_protocol_factory.getProtocol(itrans)
        oprot = output_protocol_factory.getProtocol(otrans)

        if self.server.manager.is_enabled and \
                not self.path.endswith('/Authentication') and \
                not auth_session:
            # Bail out if the user is not authenticated...
            # This response has the possibility of melting down Thrift clients,
            # but the user is expected to properly authenticate first.

            LOG.debug(client_host + ":" + str(client_port) +
                      " Invalid access, credentials not found " +
                      "- session refused.")
            self.send_error(401)
            return

        # Authentication is handled, we may now respond to the user.
        try:
            product_endpoint, api_ver, request_endpoint = \
                routing.split_client_POST_request(self.path)

            product = None
            if product_endpoint:
                # The current request came through a product route, and not
                # to the main endpoint.
                product = self.server.get_product(product_endpoint)
                self.__check_prod_db(product)

            version_supported = routing.is_supported_version(api_ver)
            if version_supported:
                major_version, _ = version_supported

                if major_version != 6:
                    # Without this, `processor` would be unbound below.
                    raise ValueError(
                        "No handlers for API major version {0}."
                        .format(major_version))

                if request_endpoint == 'Authentication':
                    auth_handler = AuthHandler_v6(
                        self.server.manager, auth_session,
                        self.server.config_session)
                    processor = AuthAPI_v6.Processor(auth_handler)
                elif request_endpoint == 'Products':
                    prod_handler = ProductHandler_v6(
                        self.server, auth_session,
                        self.server.config_session, product, version)
                    processor = ProductAPI_v6.Processor(prod_handler)
                elif request_endpoint == 'CodeCheckerService':
                    # This endpoint is a product's report_server; the product
                    # was already resolved and checked above.
                    if not product:
                        error_msg = "Requested CodeCheckerService on a " \
                                     "nonexistent product: '{0}'." \
                                    .format(product_endpoint)
                        LOG.error(error_msg)
                        raise ValueError(error_msg)

                    acc_handler = ReportHandler_v6(
                        self.server.manager, product.session_factory,
                        product, auth_session, self.server.config_session,
                        checker_md_docs, checker_md_docs_map,
                        suppress_handler, version)
                    processor = ReportAPI_v6.Processor(acc_handler)
                else:
                    LOG.debug("This API endpoint does not exist.")
                    error_msg = "No API endpoint named '{0}'." \
                                .format(self.path)
                    raise ValueError(error_msg)

            else:
                if request_endpoint == 'Authentication':
                    # API-version checking is supported on the auth endpoint.
                    handler = BadAPIHandler(api_ver)
                    processor = AuthAPI_v6.Processor(handler)
                else:
                    # Send a custom, but valid Thrift error message to the
                    # client requesting this action.
                    error_msg = "Incompatible client/server API. " \
                                "API versions supported by this server {0}." \
                                .format(get_version_str())

                    raise ValueError(error_msg)

            processor.process(iprot, oprot)

        except Exception as exn:
            # Convert every Exception to the proper format which can be parsed
            # by the Thrift clients expecting JSON responses.
            # str(exn) works on Python 2 and 3 (BaseException.message does not
            # exist on Python 3).
            LOG.error(str(exn))
            import traceback
            traceback.print_exc()
            ex = TApplicationException(TApplicationException.INTERNAL_ERROR,
                                       str(exn))
            fname, _, seqid = iprot.readMessageBegin()
            oprot.writeMessageBegin(fname, TMessageType.EXCEPTION, seqid)
            ex.write(oprot)
            oprot.writeMessageEnd()
            oprot.trans.flush()

        # Both the normal reply and the converted exception go out the same
        # way: a Thrift payload with HTTP 200.
        result = otrans.getvalue()
        self.send_response(200)
        self.send_header("content-type", "application/x-thrift")
        self.send_header("Content-Length", len(result))
        self.end_headers()
        self.wfile.write(result)
Exemple #27
0
    def __init__(
        self,
        uri=None,
        user=None,
        password=None,
        host=None,
        port=6274,
        dbname=None,
        protocol='binary',
        sessionid=None,
        bin_cert_validate=None,
        bin_ca_certs=None,
    ):
        """Open a Thrift connection to an OmniSci server and start a session.

        Connection details come from ``uri`` (parsed by ``_parse_uri``) or
        from the individual arguments, not both. Protocols:
          * ``'binary'`` — a plain socket, or a TLS socket when
            ``bin_cert_validate``/``bin_ca_certs`` are given;
          * ``'http'``/``'https'`` — JSON over HTTP.

        When ``sessionid`` is given it is validated with ``get_tables()``
        instead of logging in with ``user``/``password``/``dbname``. If
        anything fails after the transport is opened, the transport is
        closed before the error propagates.

        Raises:
            TypeError: conflicting arguments, or no ``host``.
            ValueError: unknown ``protocol``, or a transport failure while
                creating the session.
            OperationalError: the transport could not be opened.
            RuntimeError: the server version is older than 4.6.
            Any exception produced by ``_translate_exception`` for a
            ``TMapDException``.
        """

        self.sessionid = None
        if sessionid is not None:
            if any([user, password, uri, dbname]):
                raise TypeError("Cannot specify sessionid with user, password,"
                                " dbname, or uri")
        if uri is not None:
            if not all([
                    user is None, password is None, host is None, port == 6274,
                    dbname is None, protocol == 'binary',
                    bin_cert_validate is None, bin_ca_certs is None
            ]):
                raise TypeError("Cannot specify both URI and other arguments")
            user, password, host, port, dbname, protocol, \
                bin_cert_validate, bin_ca_certs = _parse_uri(uri)
        if host is None:
            raise TypeError("`host` parameter is required.")
        if protocol != 'binary' and not all(
            [bin_cert_validate is None, bin_ca_certs is None]):
            raise TypeError("Cannot specify bin_cert_validate or bin_ca_certs,"
                            " without binary protocol")
        if protocol in ("http", "https"):
            if not host.startswith(protocol):
                # the THttpClient expects http[s]://localhost
                host = '{0}://{1}'.format(protocol, host)
            transport = THttpClient.THttpClient("{}:{}".format(host, port))
            proto = TJSONProtocol.TJSONProtocol(transport)
            socket = None
        elif protocol == "binary":
            if any([bin_cert_validate is not None, bin_ca_certs]):
                socket = TSSLSocket.TSSLSocket(host,
                                               port,
                                               validate=(bin_cert_validate),
                                               ca_certs=bin_ca_certs)
            else:
                socket = TSocket.TSocket(host, port)
            transport = TTransport.TBufferedTransport(socket)
            proto = TBinaryProtocol.TBinaryProtocolAccelerated(transport)
        else:
            # One message string (the old code passed three separate args,
            # so str(err) rendered as a tuple).
            raise ValueError("`protocol` should be one of"
                             " ['http', 'https', 'binary'],"
                             " got {} instead".format(protocol))
        self._user = user
        self._password = password
        self._host = host
        self._port = port
        self._dbname = dbname
        self._transport = transport
        self._protocol = protocol
        self._socket = socket
        self._closed = 0
        self._tdf = None
        self._rbc = None
        try:
            self._transport.open()
        except TTransportException as e:
            # Compare the instance's error type; `e.NOT_OPEN` alone is the
            # class constant and is always truthy.
            if e.type == TTransportException.NOT_OPEN:
                err = OperationalError("Could not connect to database")
                raise err from e
            else:
                raise
        self._client = Client(proto)
        try:
            try:
                # If a sessionid was passed, we should validate it
                if sessionid:
                    self._session = sessionid
                    self.get_tables()
                    self.sessionid = sessionid
                else:
                    self._session = self._client.connect(user, password,
                                                         dbname)
            except TMapDException as e:
                raise _translate_exception(e) from e
            except TTransportException as e:
                raise ValueError(f"Connection failed with port {port} and "
                                 f"protocol '{protocol}'. Try port 6274 for "
                                 "protocol == binary or 6273, 6278 or 443 for "
                                 "http[s]") from e

            # if OmniSci version <4.6, raise RuntimeError, as data import can be
            # incorrect for columnar date loads
            # Caused by https://github.com/omnisci/pymapd/pull/188
            semver = self._client.get_version()
            if Version(semver.split("-")[0]) < Version("4.6"):
                raise RuntimeError(f"Version {semver} of OmniSci detected. "
                                   "Please use pymapd <0.11. See release notes "
                                   "for more details.")
        except Exception:
            # Construction failed after the transport was opened: release it
            # so failed connects do not leak sockets.
            self._transport.close()
            raise
Exemple #28
0
class thrift_utils:
    """Class-level Thrift plumbing shared by all users of this class.

    ``transport`` is a buffered transport over a socket to 192.168.56.1:9090.
    It is opened while the class body executes, i.e. at import time.
    ``protocol`` is a binary protocol bound to that transport.
    """

    transport = TTransport.TBufferedTransport(
        TSocket.TSocket('192.168.56.1', 9090))
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    transport.open()
Exemple #29
0
import sys
# Directory containing the Python modules generated from Hbase.thrift
sys.path.append('/usr/local/python2.7.3/lib/python2.7/site-packages/hbase')
from thrift import Thrift
from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol
from hbase import Hbase
# Types such as ColumnDescriptor are defined in hbase.ttypes
from hbase.ttypes import *
# Make socket
# Change the address and port here as needed
transport = TSocket.TSocket('192.168.1.220', 9090)
# Buffering is critical. Raw sockets are very slow
# TFramedTransport is another efficient transport option
transport = TTransport.TBufferedTransport(transport)
# Wrap in a protocol
# The protocol is decoupled from the transport, so several protocols can be used
protocol = TBinaryProtocol.TBinaryProtocol(transport)
# The client object issues requests on behalf of a user
client = Hbase.Client(protocol)
# Open the connection
transport.open()
try:
    # Print the table names
    print(client.getTableNames())
finally:
    # Release the connection even if the request fails
    transport.close()
Exemple #30
0
def _near_waypoint(waypoint, latitude, longitude, tolerance=0.0001):
    """Return True if (latitude, longitude) lies inside the box of half-width
    ``tolerance`` degrees centred on ``waypoint`` (a dict with "latitude" and
    "longitude" keys)."""
    lat_ok = (waypoint["latitude"] - tolerance <= latitude
              <= waypoint["latitude"] + tolerance)
    lng_ok = (waypoint["longitude"] - tolerance <= longitude
              <= waypoint["longitude"] + tolerance)
    return lat_ok and lng_ok


def main():
    """Upload a farm mission to the drone, then poll its status until it
    disarms.

    Steps: open the drone and base-door Thrift connections, open the base
    door, replace the drone's missions with one waypoint per entry of
    ``coordinate_list`` ("lat,lng" strings), then poll ``report_status``
    every 3 seconds. While the drone is near the first waypoint the camera
    is started; once it reports disarmed, the mission and drone status are
    PATCHed to their REST endpoints and the drone is switched to "RTL".

    Reads names not defined in this function (presumably module-level):
    ``port``, ``base_port``, ``args`` (``altitude``, ``drone_id``),
    ``coordinate_list``, ``mission_endpoint``, ``drone_endpoint``.
    """
    # Make socket
    transport = TSocket.TSocket('localhost', port)
    base_transport = TSocket.TSocket('localhost', base_port)

    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    base_transport = TTransport.TBufferedTransport(base_transport)

    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    base_protocol = TBinaryProtocol.TBinaryProtocol(base_transport)

    # Create a client to use the protocol encoder
    client = Drone.Client(protocol)
    base_client = BaseDoor.Client(base_protocol)

    # First waypoint as {"latitude", "longitude"}; stays None when
    # coordinate_list is empty (previously this raised NameError later on).
    first_coordinate = None

    try:
        # Connect!
        transport.open()
        base_transport.open()

        time.sleep(15)

        # Prepare doors!
        print("Opening doors...")
        base_client.openDoor()

        print("Clearing existing missions...")
        client.clear_missions()

        print("Downloading missions...")
        client.download_missions()

        print("Creating coordinate objects...")
        thrift_coordinate_list = []
        altitude = float(args.altitude)
        for coordinate in coordinate_list:
            fields = coordinate.split(",")
            lat, lng = float(fields[0]), float(fields[1])

            if first_coordinate is None:
                first_coordinate = {"latitude": lat, "longitude": lng}
                print("First waypoint is {0},{1}".format(
                    first_coordinate["latitude"],
                    first_coordinate["longitude"]))

            thrift_coordinate_list.append(
                Coordinate(latitude=lat, longitude=lng, altitude=altitude))

        print("Sending coordinates to server...")
        client.add_farm_mission(thrift_coordinate_list)
    finally:
        # Closing a Thrift socket that was never opened is a no-op, so this
        # is safe even if an open() above failed.
        transport.close()
        base_transport.close()

    print("Starting in flight status reports...")
    while True:
        # Reconnect for each status poll; the finally clause closes the
        # connection on every exit path, including the break below.
        transport.open()
        try:
            print("Reporting flight status...")
            status_obj = client.report_status(int(args.drone_id))
            print("Armed: {0}".format(status_obj.armed))

            # Compare longitude against longitude (previously latitude was
            # compared against the longitude window).
            if first_coordinate is not None and _near_waypoint(
                    first_coordinate, status_obj.latitude,
                    status_obj.longitude):
                print("Requesting camera start...")
                client.start_camera()

            if not status_obj.armed:
                # Mission is over: record it as done, free the drone, and
                # send it back to launch.
                requests.patch(mission_endpoint, data={"status": "Done"})
                requests.patch(drone_endpoint, data={"status": "Available"})
                client.change_mode("RTL")
                break
        finally:
            transport.close()
        time.sleep(3)