Exemple #1
0
def download_data(targets,
                  download_proxy=None,
                  task_name='download fashion-mnist'):
    """
    Download the data files that are missing on disk and load the known targets.

    :param targets: dict mapping a target name to a dict holding at least the keys ``url`` and ``filename``;
        for the names ``index``, ``query``, ``index-labels`` and ``query-labels`` the loaded content is
        written back under the ``data`` key
    :param download_proxy: proxy URL used for both http and https requests; no proxy handler is added when falsy
    :param task_name: description shown by the progress bar
    """
    opener = urllib.request.build_opener()
    opener.addheaders = [('User-agent', 'Mozilla/5.0')]
    if download_proxy:
        proxy = urllib.request.ProxyHandler({
            'http': download_proxy,
            'https': download_proxy
        })
        opener.add_handler(proxy)

    # urlretrieve takes no opener argument, so the opener is installed globally
    urllib.request.install_opener(opener)
    with ProgressBar(total_length=len(targets), description=task_name) as t:
        for k, v in targets.items():
            if not os.path.exists(v['filename']):
                urllib.request.urlretrieve(
                    v['url'],
                    v['filename'],
                )
            # tick once per target, downloaded or already on disk, so that the
            # bar can reach the total_length it was created with
            t.update()
            if k in ('index-labels', 'query-labels'):
                v['data'] = load_labels(v['filename'])
            elif k in ('index', 'query'):
                v['data'] = load_mnist(v['filename'])
Exemple #2
0
def download_data(targets,
                  download_proxy=None,
                  task_name='download fashion-mnist'):
    """
    Fetch every target file that does not exist on disk yet.

    :param targets: mapping whose values provide the ``url`` to fetch and the ``filename`` to store it under
    :param download_proxy: proxy used for both http and https requests; skipped when falsy
    :param task_name: description passed to the progress bar
    """
    url_opener = urllib.request.build_opener()
    url_opener.addheaders = [('User-agent', 'Mozilla/5.0')]
    if download_proxy:
        proxy_handler = urllib.request.ProxyHandler(
            dict.fromkeys(('http', 'https'), download_proxy))
        url_opener.add_handler(proxy_handler)
    urllib.request.install_opener(url_opener)

    with ProgressBar(description=task_name) as bar:

        def _tick(*_):
            # report hook for urlretrieve: its arguments are ignored and the
            # bar is advanced by a fixed amount
            return bar.update(0.01)

        for _, spec in targets.items():
            if os.path.exists(spec['filename']):
                continue
            urllib.request.urlretrieve(spec['url'],
                                       spec['filename'],
                                       reporthook=_tick)
Exemple #3
0
def test_kb_interrupt_in_progress_bar(capsys):
    """Check that stdout contains the CANCELED status after a KeyboardInterrupt is raised inside the bar."""
    with ProgressBar('update') as bar:
        for step in range(100):
            bar.update()
            time.sleep(0.01)
            # interrupt once a few updates have been made
            if step >= 6:
                raise KeyboardInterrupt
    stdout = capsys.readouterr().out
    assert str(ProgressBarStatus.CANCELED) in stdout
Exemple #4
0
def test_progressbar(total_steps, update_tick, task_name, capsys, details,
                     msg_on_done):
    """Drive a bar for ``total_steps`` updates and check that 'estimating' is printed to stdout."""
    bar_cm = ProgressBar(description=task_name, message_on_done=msg_on_done)
    with bar_cm as bar:
        for step in range(total_steps):
            # only build a per-step message when a template was supplied
            if details:
                step_message = details.format(step)
            else:
                step_message = None
            bar.update(update_tick, message=step_message)
            time.sleep(0.001)

    stdout = capsys.readouterr().out
    assert 'estimating' in stdout
Exemple #5
0
def test_error_in_progress_bar(capsys):
    """Check that an exception raised inside the bar propagates and that the ERROR status is printed."""
    with pytest.raises(NotImplementedError), ProgressBar('update') as bar:
        for step in range(100):
            bar.update()
            time.sleep(0.01)
            # fail once a few updates have been made
            if step >= 6:
                raise NotImplementedError
    stdout = capsys.readouterr().out
    assert str(ProgressBarStatus.ERROR) in stdout
Exemple #6
0
def test_progressbar(total_steps, update_tick, task_name, capsys, msg_on_done):
    """Drive a bar with a known total length and check that 'ETA' is printed to stdout."""
    bar_options = dict(
        description=task_name,
        message_on_done=msg_on_done,
        total_length=total_steps,
    )
    with ProgressBar(**bar_options) as bar:
        for _ in range(total_steps):
            bar.update(advance=update_tick)
            time.sleep(0.001)

    stdout = capsys.readouterr().out
    assert 'ETA' in stdout
Exemple #7
0
def download_data(target: dict, download_proxy=None):
    """
    Download the files that are missing on disk and load the known targets.

    :param target: dict mapping a target name to a dict with the keys ``url`` and ``filename``; the loaded
        content of the names ``index``, ``query``, ``index-labels`` and ``query-labels`` is stored under ``data``
    :param download_proxy: proxy used for both http and https requests; skipped when falsy
    """
    url_opener = urllib.request.build_opener()
    if download_proxy:
        proxy_handler = urllib.request.ProxyHandler(
            dict.fromkeys(('http', 'https'), download_proxy))
        url_opener.add_handler(proxy_handler)
    urllib.request.install_opener(url_opener)

    label_names = ('index-labels', 'query-labels')
    image_names = ('index', 'query')
    with ProgressBar(task_name='download fashion-mnist', batch_unit='') as bar:

        def _tick(*_):
            # report hook for urlretrieve: its arguments are ignored
            return bar.update(1)

        for name, spec in target.items():
            already_there = os.path.exists(spec['filename'])
            if not already_there:
                urllib.request.urlretrieve(spec['url'],
                                           spec['filename'],
                                           reporthook=_tick)
            if name in label_names:
                spec['data'] = load_labels(spec['filename'])
            if name in image_names:
                spec['data'] = load_mnist(spec['filename'])
Exemple #8
0
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        compression: Optional[str] = None,
        **kwargs,
    ):
        """
        Send the requests over a gRPC channel and yield each response as it arrives.

        :param inputs: the inputs, stored on ``self.inputs`` before the request iterator is built
        :param on_done: the callback for on_done, passed to ``callback_exec`` with every response
        :param on_error: the callback for on_error; also called with the exception when a gRPC error is handled
        :param on_always: the callback for on_always; also called once when a gRPC error is handled
        :param compression: name of the ``grpc.Compression`` member to use, ``grpc.Compression.NoCompression``
            is used when it is empty or ``None``
        :param kwargs: kwargs for ``_get_requests``
        :yields: the responses received from ``stub.Call``
        """
        try:
            # resolve the name to a grpc.Compression member; an unknown name makes
            # getattr raise AttributeError, which none of the handlers below catch
            self.compression = (
                getattr(grpc.Compression, compression)
                if compression
                else grpc.Compression.NoCompression
            )

            self.inputs = inputs
            req_iter = self._get_requests(**kwargs)
            async with GrpcConnectionPool.get_grpc_channel(
                f'{self.args.host}:{self.args.port}',
                asyncio=True,
                tls=self.args.tls,
            ) as channel:
                stub = jina_pb2_grpc.JinaRPCStub(channel)
                self.logger.debug(f'connected to {self.args.host}:{self.args.port}')

                # the bar is always created; `disable` is set when show_progress is off
                with ProgressBar(
                    total_length=self._inputs_length, disable=not (self.show_progress)
                ) as p_bar:

                    async for resp in stub.Call(
                        req_iter,
                        compression=self.compression,
                    ):
                        # run the user callbacks before the response is handed to the consumer
                        callback_exec(
                            response=resp,
                            on_error=on_error,
                            on_done=on_done,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                        if self.show_progress:
                            p_bar.update()
                        yield resp

        except KeyboardInterrupt:
            # logged only, the interrupt is not propagated
            self.logger.warning('user cancel the process')
        except asyncio.CancelledError as ex:
            # logged only, the cancellation is not re-raised
            self.logger.warning(f'process error: {ex!r}')
        except (grpc.aio._call.AioRpcError, InternalNetworkError) as err:
            my_code = err.code()
            my_details = err.details()
            msg = f'gRPC error: {my_code} {my_details}'

            # every branch of this inner try raises; the except clause below then
            # decides whether the exception goes to the callbacks or to the caller
            try:
                if my_code == grpc.StatusCode.UNAVAILABLE:
                    self.logger.error(
                        f'{msg}\nThe ongoing request is terminated as the server is not available or closed already.'
                    )
                    # `from None` suppresses the gRPC error as context of the ConnectionError
                    raise ConnectionError(my_details) from None
                elif my_code == grpc.StatusCode.DEADLINE_EXCEEDED:
                    self.logger.error(
                        f'{msg}\nThe ongoing request is terminated due to a server-side timeout.'
                    )
                    raise ConnectionError(my_details) from None
                elif my_code == grpc.StatusCode.INTERNAL:
                    self.logger.error(f'{msg}\ninternal error on the server side')
                    raise err
                elif (
                    my_code == grpc.StatusCode.UNKNOWN
                    and 'asyncio.exceptions.TimeoutError' in my_details
                ):
                    raise BadClientInput(
                        f'{msg}\n'
                        'often the case is that you define/send a bad input iterator to jina, '
                        'please double check your input iterator'
                    ) from err
                else:
                    raise BadClient(msg) from err

            except (
                grpc.aio._call.AioRpcError,
                BaseJinaException,
                ConnectionError,
            ) as e:  # depending on if there are callbacks we catch or not the exception
                # with at least one of on_error/on_always the exception is handed to
                # the callbacks and swallowed; without any callback it is re-raised
                if on_error or on_always:
                    if on_error:
                        callback_exec_on_error(on_error, e, self.logger)
                    if on_always:
                        # there is no response at this point, so on_always runs with None
                        callback_exec(
                            response=None,
                            on_error=None,
                            on_done=None,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                else:
                    raise e
Exemple #9
0
def test_never_call_update(capsys):
    """Check that a bar which is never updated leaves stdout ending with ``ProgressBar.clear_line``."""
    with ProgressBar():
        # deliberately no call to update()
        pass
    stdout = capsys.readouterr().out
    assert stdout.endswith(ProgressBar.clear_line)
Exemple #10
0
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        **kwargs,
    ):
        """
        Send the requests over a websocket connection and yield each response as it arrives.

        :param inputs: the inputs, stored on ``self.inputs`` before the request iterator is built
        :param on_done: the callback for on_done
        :param on_error: the callback for on_error; also called with the exception on an ``aiohttp.ClientError``
        :param on_always: the callback for on_always; also called once on an ``aiohttp.ClientError``
        :param kwargs: kwargs for _get_requests
        :yields: generator over results
        """
        with ImportExtensions(required=True):
            import aiohttp

        self.inputs = inputs
        request_iterator = self._get_requests(**kwargs)

        # the stack closes the progress bar and the websocket clientlet on exit
        async with AsyncExitStack() as stack:
            try:
                # the bar is always created; `disable` is set when show_progress is off
                cm1 = ProgressBar(
                    total_length=self._inputs_length, disable=not (self.show_progress)
                )
                p_bar = stack.enter_context(cm1)

                proto = 'wss' if self.args.tls else 'ws'
                url = f'{proto}://{self.args.host}:{self.args.port}/'
                iolet = await stack.enter_async_context(
                    WebsocketClientlet(url=url, logger=self.logger)
                )

                # request id -> future that is resolved with the matching response
                request_buffer: Dict[str, asyncio.Future] = dict()

                def _result_handler(result):
                    # results are passed through unchanged
                    return result

                async def _receive():
                    def _response_handler(response):
                        # resolve the future registered for this request id, if there is one
                        if response.header.request_id in request_buffer:
                            future = request_buffer.pop(response.header.request_id)
                            future.set_result(response)
                        else:
                            self.logger.warning(
                                f'discarding unexpected response with request id {response.header.request_id}'
                            )

                    # NOTE: the string below comes after the nested def, so it is a plain
                    # expression statement and not the docstring of `_receive`
                    """Await messages from WebsocketGateway and process them in the request buffer"""
                    try:
                        async for response in iolet.recv_message():
                            _response_handler(response)
                    finally:
                        # cancel and drop every future that is still waiting for a response
                        if request_buffer:
                            self.logger.warning(
                                f'{self.__class__.__name__} closed, cancelling all outstanding requests'
                            )
                            for future in request_buffer.values():
                                future.cancel()
                            request_buffer.clear()

                def _handle_end_of_iter():
                    """Send End of iteration signal to the Gateway"""
                    # scheduled as a task, the handler does not wait for it
                    asyncio.create_task(iolet.send_eoi())

                def _request_handler(request: 'Request') -> 'asyncio.Future':
                    """
                    For each request in the iterator, we send the `Message` using `iolet.send_message()`.
                    For websocket requests from client, for each request in the iterator, we send the request in `bytes`
                    using `iolet.send_message()`.
                    Then add {<request-id>: <an-empty-future>} to the request buffer.
                    This empty future is used to track the `result` of this request during `receive`.
                    :param request: current request in the iterator
                    :return: asyncio Future that is resolved with the response to this request
                    """
                    future = get_or_reuse_loop().create_future()
                    # register the future before the send is scheduled
                    request_buffer[request.header.request_id] = future
                    asyncio.create_task(iolet.send_message(request))
                    return future

                streamer = RequestStreamer(
                    args=self.args,
                    request_handler=_request_handler,
                    result_handler=_result_handler,
                    end_of_iter_handler=_handle_end_of_iter,
                )

                # start receiving responses in the background
                receive_task = get_or_reuse_loop().create_task(_receive())

                # guard: refuse to stream when the receive task has already finished
                if receive_task.done():
                    raise RuntimeError(
                        'receive task not running, can not send messages'
                    )
                async for response in streamer.stream(request_iterator):
                    # run the user callbacks before the response is handed to the consumer
                    callback_exec(
                        response=response,
                        on_error=on_error,
                        on_done=on_done,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
                    if self.show_progress:
                        p_bar.update()
                    yield response

            except aiohttp.ClientError as e:
                self.logger.error(
                    f'Error while streaming response from websocket server {e!r}'
                )

                # with at least one of on_error/on_always the exception is handed to
                # the callbacks and swallowed; without any callback it is re-raised
                if on_error or on_always:
                    if on_error:
                        callback_exec_on_error(on_error, e, self.logger)
                    if on_always:
                        # there is no response at this point, so on_always runs with None
                        callback_exec(
                            response=None,
                            on_error=None,
                            on_done=None,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                else:
                    raise e
Exemple #11
0
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        **kwargs,
    ):
        """
        Send the requests over a gRPC channel and yield each response as it arrives.

        :param inputs: the inputs, stored on ``self.inputs`` before the request iterator is built
        :param on_done: the callback for on_done, passed to ``callback_exec`` with every response
        :param on_error: the callback for on_error, passed to ``callback_exec`` with every response
        :param on_always: the callback for on_always, passed to ``callback_exec`` with every response
        :param kwargs: kwargs for ``_get_requests``
        :yields: the responses received from ``stub.Call``
        """
        try:
            self.inputs = inputs
            req_iter = self._get_requests(**kwargs)
            async with GrpcConnectionPool.get_grpc_channel(
                    f'{self.args.host}:{self.args.port}',
                    asyncio=True,
                    https=self.args.https,
            ) as channel:
                stub = jina_pb2_grpc.JinaRPCStub(channel)
                self.logger.debug(
                    f'connected to {self.args.host}:{self.args.port}')

                # without progress display a nullcontext stands in for the bar, so
                # p_bar is None in that case and must not be used
                cm1 = (ProgressBar(total_length=self._inputs_length)
                       if self.show_progress else nullcontext())

                with cm1 as p_bar:
                    async for resp in stub.Call(req_iter):
                        # run the user callbacks before the response is handed to the consumer
                        callback_exec(
                            response=resp,
                            on_error=on_error,
                            on_done=on_done,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                        if self.show_progress:
                            p_bar.update()
                        yield resp
        except KeyboardInterrupt:
            # logged only, the interrupt is not propagated
            self.logger.warning('user cancel the process')
        except asyncio.CancelledError as ex:
            # logged only, the cancellation is not re-raised
            self.logger.warning(f'process error: {ex!r}')
        except grpc.aio._call.AioRpcError as rpc_ex:
            # Since this object is guaranteed to be a grpc.Call, might as well include that in its name.
            my_code = rpc_ex.code()
            my_details = rpc_ex.details()
            msg = f'gRPC error: {my_code} {my_details}'
            # every branch below raises: UNAVAILABLE and INTERNAL re-raise the gRPC
            # error after logging, the others wrap it in a client exception
            if my_code == grpc.StatusCode.UNAVAILABLE:
                self.logger.error(
                    f'{msg}\nthe ongoing request is terminated as the server is not available or closed already'
                )
                raise rpc_ex
            elif my_code == grpc.StatusCode.INTERNAL:
                self.logger.error(f'{msg}\ninternal error on the server side')
                raise rpc_ex
            elif (my_code == grpc.StatusCode.UNKNOWN
                  and 'asyncio.exceptions.TimeoutError' in my_details):
                raise BadClientInput(
                    f'{msg}\n'
                    'often the case is that you define/send a bad input iterator to jina, '
                    'please double check your input iterator') from rpc_ex
            else:
                raise BadClient(msg) from rpc_ex
Exemple #12
0
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        **kwargs,
    ):
        """
        Send the requests as HTTP POSTs and yield a ``DataRequest`` for each response.

        :param inputs: the inputs, stored on ``self.inputs`` before the request iterator is built
        :param on_done: the callback for on_done
        :param on_error: the callback for on_error
        :param on_always: the callback for on_always
        :param kwargs: kwargs for _get_requests
        :yields: generator over results
        """
        with ImportExtensions(required=True):
            import aiohttp

        self.inputs = inputs
        request_iterator = self._get_requests(**kwargs)

        # the stack closes the progress bar (if any) and the HTTP clientlet on exit
        async with AsyncExitStack() as stack:
            try:
                # without progress display a nullcontext stands in for the bar, so
                # p_bar is None in that case and must not be used
                cm1 = (ProgressBar(total_length=self._inputs_length)
                       if self.show_progress else nullcontext())
                p_bar = stack.enter_context(cm1)

                proto = 'https' if self.args.https else 'http'
                url = f'{proto}://{self.args.host}:{self.args.port}/post'
                iolet = await stack.enter_async_context(
                    HTTPClientlet(url=url, logger=self.logger))

                def _request_handler(request: 'Request') -> 'asyncio.Future':
                    """
                    For HTTP Client, for each request in the iterator, we `send_message` using
                    http POST request and add it to the list of tasks which is awaited and yielded.
                    :param request: current request in the iterator
                    :return: asyncio Task for sending message
                    """
                    return asyncio.ensure_future(
                        iolet.send_message(request=request))

                def _result_handler(result):
                    # results are passed through unchanged
                    return result

                streamer = RequestStreamer(
                    self.args,
                    request_handler=_request_handler,
                    result_handler=_result_handler,
                )
                async for response in streamer.stream(request_iterator):
                    r_status = response.status

                    # the body is parsed as JSON before the status code is inspected
                    r_str = await response.json()
                    if r_status == 404:
                        raise BadClient(f'no such endpoint {url}')
                    # NOTE(review): a status of exactly 300 passes this check - confirm
                    # whether `>= 300` was intended
                    elif r_status < 200 or r_status > 300:
                        raise ValueError(r_str)

                    # the 'data' entry is converted separately and removed from the
                    # payload before the DataRequest is built from the remainder
                    da = None
                    if 'data' in r_str and r_str['data'] is not None:
                        from docarray import DocumentArray

                        da = DocumentArray.from_dict(r_str['data'])
                        del r_str['data']

                    resp = DataRequest(r_str)
                    if da is not None:
                        resp.data.docs = da

                    # run the user callbacks before the response is handed to the consumer
                    callback_exec(
                        response=resp,
                        on_error=on_error,
                        on_done=on_done,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
                    if self.show_progress:
                        p_bar.update()
                    yield resp
            except aiohttp.ClientError as e:
                # NOTE(review): the client error is only logged here; it is neither
                # re-raised nor passed to on_error/on_always - confirm this is intended
                self.logger.error(
                    f'Error while fetching response from HTTP server {e!r}')
Exemple #13
0
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        compression: str = 'NoCompression',
        **kwargs,
    ):
        """
        Send the requests over a gRPC channel and yield each response as it arrives.

        :param inputs: the inputs, stored on ``self.inputs`` before the request iterator is built
        :param on_done: the callback for on_done, passed to ``callback_exec`` with every response
        :param on_error: the callback for on_error; also called with the exception when a gRPC error is handled
        :param on_always: the callback for on_always; also called once when a gRPC error is handled
        :param compression: name of the compression algorithm, looked up case-insensitively in
            ``GRPC_COMPRESSION_MAP``; an unknown name triggers a warning and ``NoCompression`` is used instead
        :param kwargs: kwargs for ``_get_requests``
        :yields: the responses received from ``stub.Call``
        """
        try:
            # unknown algorithm: warn and fall back instead of failing
            if compression.lower() not in GRPC_COMPRESSION_MAP:
                import warnings

                warnings.warn(
                    message=
                    f'Your compression "{compression}" is not supported. Supported '
                    f'algorithms are `Gzip`, `Deflate` and `NoCompression`. NoCompression will be used as '
                    f'default')
                compression = 'NoCompression'
            self.inputs = inputs
            req_iter = self._get_requests(**kwargs)
            async with GrpcConnectionPool.get_grpc_channel(
                    f'{self.args.host}:{self.args.port}',
                    asyncio=True,
                    tls=self.args.tls,
            ) as channel:
                stub = jina_pb2_grpc.JinaRPCStub(channel)
                self.logger.debug(
                    f'connected to {self.args.host}:{self.args.port}')

                # the bar is always created; `disable` is set when show_progress is off
                with ProgressBar(total_length=self._inputs_length,
                                 disable=not (self.show_progress)) as p_bar:

                    # the lookup defaults to NoCompression when the key is missing from the map
                    async for resp in stub.Call(
                            req_iter,
                            compression=GRPC_COMPRESSION_MAP.get(
                                compression.lower(),
                                grpc.Compression.NoCompression),
                    ):
                        # run the user callbacks before the response is handed to the consumer
                        callback_exec(
                            response=resp,
                            on_error=on_error,
                            on_done=on_done,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                        if self.show_progress:
                            p_bar.update()
                        yield resp

        except KeyboardInterrupt:
            # logged only, the interrupt is not propagated
            self.logger.warning('user cancel the process')
        except asyncio.CancelledError as ex:
            # logged only, the cancellation is not re-raised
            self.logger.warning(f'process error: {ex!r}')
        except grpc.aio._call.AioRpcError as rpc_ex:
            # Since this object is guaranteed to be a grpc.Call, might as well include that in its name.
            my_code = rpc_ex.code()
            my_details = rpc_ex.details()
            msg = f'gRPC error: {my_code} {my_details}'

            # every branch of this inner try raises; the except clause below then
            # decides whether the exception goes to the callbacks or to the caller
            try:
                if my_code == grpc.StatusCode.UNAVAILABLE:
                    self.logger.error(
                        f'{msg}\nthe ongoing request is terminated as the server is not available or closed already'
                    )
                    raise rpc_ex
                elif my_code == grpc.StatusCode.INTERNAL:
                    self.logger.error(
                        f'{msg}\ninternal error on the server side')
                    raise rpc_ex
                elif (my_code == grpc.StatusCode.UNKNOWN
                      and 'asyncio.exceptions.TimeoutError' in my_details):
                    raise BadClientInput(
                        f'{msg}\n'
                        'often the case is that you define/send a bad input iterator to jina, '
                        'please double check your input iterator') from rpc_ex
                else:
                    raise BadClient(msg) from rpc_ex

            except (
                    grpc.aio._call.AioRpcError,
                    BaseJinaException,
            ) as e:  # depending on if there are callbacks we catch or not the exception
                # with at least one of on_error/on_always the exception is handed to
                # the callbacks and swallowed; without any callback it is re-raised
                if on_error or on_always:
                    if on_error:
                        callback_exec_on_error(on_error, e, self.logger)
                    if on_always:
                        # there is no response at this point, so on_always runs with None
                        callback_exec(
                            response=None,
                            on_error=None,
                            on_done=None,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                else:
                    raise e