def external_deployment_shards_2_args(num_shards):
    args = [
        '--uses',
        'MyExternalExecutor',
        '--name',
        'external_real_2',
        '--port',
        str(random_port()),
        '--shards',
        str(num_shards),
        '--polling',
        'all',
    ]
    return set_deployment_parser().parse_args(args)
Example #2
def test_disable_prefetch_slow_client_fast_executor(protocol, inputs,
                                                    monkeypatch,
                                                    simple_graph_dict_fast):
    monkeypatch.setattr(
        networking.GrpcConnectionPool,
        'send_requests_once',
        DummyMockConnectionPool.send_requests_once,
    )
    port_in = random_port()

    p = multiprocessing.Process(
        target=create_runtime,
        kwargs={
            'protocol': protocol,
            'port_in': port_in,
            'graph_dict': simple_graph_dict_fast,
        },
    )
    p.start()
    time.sleep(1.0)

    final_da = DocumentArray()

    client = Client(protocol=protocol, port=port_in)
    client.post(
        on='/',
        inputs=inputs,
        request_size=1,
        on_done=lambda response: on_done(response, final_da),
    )
    p.terminate()
    p.join()
    assert len(final_da) == INPUT_LEN
    # Since the input_gen is slow, order will always be gen -> exec -> on_done for every request
    for i in range(4):
        doc = final_da[f'id-{i}']
        assert doc.tags['input_gen'] < doc.tags['executor']
        assert doc.tags['executor'] < doc.tags['on_done']
        if i < 3:
            assert doc.tags['on_done'] < final_da[f'id-{i + 1}'].tags['input_gen']
def test_top_k_with_rest_api(query_dict):
    port = helper.random_port()
    with Flow(rest_api=True, port_expose=port).add():
        # temporarily adding sleep
        time.sleep(0.5)
        query = json.dumps(query_dict).encode('utf-8')
        req = request.Request(
            f'http://0.0.0.0:{port}/api/search',
            data=query,
            headers={'content-type': 'application/json'},
        )
        resp = request.urlopen(req).read().decode('utf8')
        assert json.loads(resp)['queryset'][0]['name'] == 'SliceQL'
        assert json.loads(resp)['queryset'][0]['parameters']['end'] == TOP_K
        assert json.loads(resp)['queryset'][0]['priority'] == 1
def external_executor_args():
    args = [
        '--uses',
        'MyExternalExecutor',
        '--name',
        'external_real',
        '--port-in',
        str(random_port()),
        '--host-in',
        '0.0.0.0',
        '--socket-in',
        'ROUTER_BIND',
        '--dynamic-routing-out',
    ]
    return set_pea_parser().parse_args(args)
Example #5
def external_pod_args(num_replicas, num_parallel):
    args = [
        '--uses',
        'MyExternalExecutor',
        '--name',
        'external_real',
        '--port-in',
        str(random_port()),
        '--host-in',
        '0.0.0.0',
        '--parallel',
        str(num_parallel),
        '--replicas',
        str(num_replicas),
    ]
    return set_pod_parser().parse_args(args)
Example #6
def test_tag_update(grpc_data_requests):
    PORT_EXPOSE = random_port()

    f = Flow(
        port_expose=PORT_EXPOSE,
        protocol='http',
        grpc_data_requests=grpc_data_requests,
    ).add(uses=TestExecutor)

    with f:
        d1 = {"data": [{"id": "1", "prop1": "val"}]}
        d2 = {"data": [{"id": "2", "prop2": "val"}]}
        r1 = req.post(f'http://localhost:{PORT_EXPOSE}/index', json=d1)
        assert r1.json()['data']['docs'][0]['tags'] == {'prop1': 'val'}
        r2 = req.post(f'http://localhost:{PORT_EXPOSE}/index', json=d2)
        assert r2.json()['data']['docs'][0]['tags'] == {'prop2': 'val'}
def test_grpc_gateway_runtime_handle_messages_complete_graph_dict(
    complete_graph_dict, monkeypatch, protocol
):
    # TODO: Test incomplete until merging of responses is ready
    monkeypatch.setattr(
        networking.GrpcConnectionPool,
        'send_requests_once',
        DummyMockConnectionPool.send_requests_once,
    )
    port = random_port()

    def client_validate(client_id: int):
        responses = client_send(client_id, port, protocol)
        assert len(responses) > 0
        assert len(responses[0].docs) == 1
        # there are 3 incoming paths to merger, it could be any
        assert (
            f'client{client_id}-Request-client{client_id}-deployment0-client{client_id}-deployment1-client{client_id}-merger-client{client_id}-deployment_last'
            == responses[0].docs[0].text
            or f'client{client_id}-Request-client{client_id}-deployment0-client{client_id}-deployment2-client{client_id}-deployment3-client{client_id}-merger-client{client_id}-deployment_last'
            == responses[0].docs[0].text
            or f'client{client_id}-Request-client{client_id}-deployment4-client{client_id}-deployment5-client{client_id}-merger-client{client_id}-deployment_last'
            == responses[0].docs[0].text
        )

    p = multiprocessing.Process(
        target=create_runtime,
        kwargs={
            'protocol': protocol,
            'port': port,
            'graph_dict': complete_graph_dict,
        },
    )
    p.start()
    time.sleep(1.0)
    client_processes = []
    for i in range(NUM_PARALLEL_CLIENTS):
        cp = multiprocessing.Process(target=client_validate, kwargs={'client_id': i})
        cp.start()
        client_processes.append(cp)

    for cp in client_processes:
        cp.join()
    p.terminate()
    p.join()
    for cp in client_processes:
        assert cp.exitcode == 0
def test_flowstore_update(partial_flow_store, mocker):
    flow_model = FlowModel()
    flow_model.uses = f'{cur_dir}/flow.yml'
    port_expose = helper.random_port()
    args = ArgNamespace.kwargs2namespace(flow_model.dict(), set_flow_parser())

    partial_flow_store.add(args, port_expose)

    update_mock = mocker.Mock()
    partial_flow_store.object.rolling_update = update_mock

    partial_flow_store.update(kind=UpdateOperation.ROLLING_UPDATE,
                              dump_path='',
                              pod_name='pod1',
                              shards=1)

    update_mock.assert_called()
    def port_expose(self) -> int:
        """
        Sets `port_expose` for the Flow started in `partial-daemon`.
        This port needs to be exposed before starting `partial-daemon`, hence set here.
        If env vars are passed, set them in current env, to make sure all values are replaced.
        Before loading the flow yaml, change CWD to workspace dir.

        :return: port_expose
        """
        with ExitStack() as stack:
            if self.envs:
                for key, val in self.envs.items():
                    stack.enter_context(change_env(key, val))
            stack.enter_context(
                change_cwd(get_workspace_path(self.workspace_id)))
            f = Flow.load_config(str(self.localpath()))
            return f.port_expose or random_port()
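
The property above falls back to a random free port whenever the Flow YAML does not pin `port_expose`. The same fallback can be reproduced without Jina's helper; a minimal sketch using only the standard library (the name `pick_free_port` is illustrative, not part of Jina):

import socket


def pick_free_port(configured=None):
    """Return the configured port if given, otherwise ask the OS for a free one."""
    if configured:
        return configured
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))  # binding to port 0 lets the OS choose an unused port
        return s.getsockname()[1]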
def test_no_matches_rest(query_dict):
    port = helper.random_port()
    with Flow(rest_api=True, port_expose=port).add(uses='!MockExecutor'):
        # temporarily adding sleep
        time.sleep(0.5)
        query = json.dumps(query_dict).encode('utf-8')
        req = request.Request(
            f'http://0.0.0.0:{port}/search',
            data=query,
            headers={'content-type': 'application/json'},
        )
        resp = request.urlopen(req).read().decode('utf8')
        doc = json.loads(resp)['search']['docs'][0]
        present_keys = sorted(doc.keys())
        for field in _document_fields:
            if field not in IGNORED_FIELDS + ['buffer', 'content', 'blob']:
                assert field in present_keys
Example #11
def test_grpc_gateway_runtime_handle_messages_merge_in_gateway(
        merge_graph_dict_directly_merge_in_gateway, monkeypatch, protocol):
    # TODO: Test incomplete until merging of responses is ready
    monkeypatch.setattr(
        networking.GrpcConnectionPool,
        'send_requests_once',
        DummyMockConnectionPool.send_requests_once,
    )
    port_in = random_port()

    def client_validate(client_id: int):
        responses = client_send(client_id, port_in, protocol)
        assert len(responses) > 0
        assert len(responses[0].docs) == 1
        deployment1_path = (
            f'client{client_id}-Request-client{client_id}-deployment0-client{client_id}-deployment1-client{client_id}-merger'
            in responses[0].docs[0].text)
        deployment2_path = (
            f'client{client_id}-Request-client{client_id}-deployment0-client{client_id}-deployment2-client{client_id}-merger'
            in responses[0].docs[0].text)
        assert deployment1_path or deployment2_path

    p = multiprocessing.Process(
        target=create_runtime,
        kwargs={
            'protocol': protocol,
            'port_in': port_in,
            'graph_dict': merge_graph_dict_directly_merge_in_gateway,
        },
    )
    p.start()
    time.sleep(1.0)
    client_processes = []
    for i in range(NUM_PARALLEL_CLIENTS):
        cp = multiprocessing.Process(target=client_validate,
                                     kwargs={'client_id': i})
        cp.start()
        client_processes.append(cp)

    for cp in client_processes:
        cp.join()
    p.terminate()
    p.join()
    for cp in client_processes:
        assert cp.exitcode == 0
def external_deployment_join_args(num_replicas, num_shards):
    args = [
        '--uses',
        'MyExternalExecutor',
        '--name',
        'external_real',
        '--port-in',
        str(random_port()),
        '--deployment-role',
        'JOIN',
        '--shards',
        str(num_shards),
        '--replicas',
        str(num_replicas),
        '--polling',
        'all',
    ]
    return set_deployment_parser().parse_args(args)
Example #13
def external_pod_pre_shards_args(num_replicas, num_shards):
    args = [
        '--uses',
        'MyExternalExecutor',
        '--name',
        'external_real',
        '--port-in',
        str(random_port()),
        '--host-in',
        '0.0.0.0',
        '--shards',
        str(num_shards),
        '--replicas',
        str(num_replicas),
        '--polling',
        'all',
    ]
    return set_pod_parser().parse_args(args)
Example #14
def test_grpc_gateway_runtime_handle_messages_bifurcation(
        bifurcation_graph_dict, monkeypatch, protocol):
    monkeypatch.setattr(
        networking.GrpcConnectionPool,
        'send_requests_once',
        DummyMockConnectionPool.send_requests_once,
    )
    port = random_port()

    def client_validate(client_id: int):
        responses = client_send(client_id, port, protocol)
        assert len(responses) > 0
        # reducing is supposed to happen in the deployments, in the test it will get a single doc in non deterministic order
        assert len(responses[0].docs) == 1
        assert (
            responses[0].docs[0].text ==
            f'client{client_id}-Request-client{client_id}-deployment0-client{client_id}-deployment2-client{client_id}-deployment3'
            or responses[0].docs[0].text ==
            f'client{client_id}-Request-client{client_id}-deployment4-client{client_id}-deployment5'
        )

    p = multiprocessing.Process(
        target=create_runtime,
        kwargs={
            'protocol': protocol,
            'port': port,
            'graph_dict': bifurcation_graph_dict,
        },
    )
    p.start()
    time.sleep(1.0)
    client_processes = []
    for i in range(NUM_PARALLEL_CLIENTS):
        cp = multiprocessing.Process(target=client_validate,
                                     kwargs={'client_id': i})
        cp.start()
        client_processes.append(cp)

    for cp in client_processes:
        cp.join()
    p.terminate()
    p.join()
    for cp in client_processes:
        assert cp.exitcode == 0
Example #15
        async def _scale_up(self, replicas: int):
            new_pods = []
            new_args_list = []
            for i in range(len(self._pods), replicas):
                new_args = copy.copy(self.args[0])
                new_args.noblock_on_start = True
                new_args.name = new_args.name[:-1] + f'{i}'
                new_args.port_in = helper.random_port()
                # no exception should happen at create and enter time
                new_pods.append(PodFactory.build_pod(new_args).start())
                new_args_list.append(new_args)
            exception = None
            for new_pod, new_args in zip(new_pods, new_args_list):
                try:
                    await new_pod.async_wait_start_success()
                    await GrpcConnectionPool.activate_worker(
                        worker_host=Deployment.get_worker_host(
                            new_args, new_pod, self.head_pod),
                        worker_port=new_args.port_in,
                        target_head=
                        f'{self.head_pod.args.host}:{self.head_pod.args.port_in}',
                        shard_id=self.shard_id,
                    )
                except (
                        RuntimeFailToStart,
                        TimeoutError,
                        RuntimeRunForeverEarlyError,
                ) as ex:
                    exception = ex
                    break

            if exception is not None:
                if self.deployment_args.shards > 1:
                    msg = f'Scaling fails for shard {self.deployment_args.shard_id}'
                else:
                    msg = 'Scaling fails'

                msg += f' due to executor failing to start with exception: {exception!r}'
                raise ScalingFails(msg)
            else:
                for new_pod, new_args in zip(new_pods, new_args_list):
                    self.args.append(new_args)
                    self._pods.append(new_pod)
Example #16
def test_index_remote_rpi(test_workspace):
    f_args = set_gateway_parser().parse_args(['--host', '0.0.0.0'])

    def start_gateway():
        with Pod(f_args):
            time.sleep(3)

    t = mp.Process(target=start_gateway)
    t.daemon = True
    t.start()

    f = Flow(optimize_level=FlowOptimizeLevel.IGNORE_GATEWAY).add(
        uses=os.path.join(cur_dir, 'yaml/test-index-remote.yml'),
        parallel=3,
        host='0.0.0.0',
        port_expose=random_port())

    with f:
        f.index(input_fn=random_docs(1000))
Example #17
def test_grpc_gateway_runtime_lazy_request_access(linear_graph_dict,
                                                  monkeypatch):
    call_counts = multiprocessing.Queue()

    monkeypatch.setattr(
        networking.GrpcConnectionPool,
        'send_requests_once',
        DummyNoDocAccessMockConnectionPool.send_requests_once,
    )
    port = random_port()

    def client_validate(client_id: int):
        responses = client_send(client_id, port, 'grpc')
        assert len(responses) > 0

    p = multiprocessing.Process(
        target=create_runtime,
        kwargs={
            'protocol': 'grpc',
            'port': port,
            'graph_dict': linear_graph_dict,
            'call_counts': call_counts,
            'monkeypatch': monkeypatch,
        },
    )
    p.start()
    time.sleep(1.0)
    client_processes = []
    for i in range(NUM_PARALLEL_CLIENTS):
        cp = multiprocessing.Process(target=client_validate,
                                     kwargs={'client_id': i})
        cp.start()
        client_processes.append(cp)

    for cp in client_processes:
        cp.join()
    p.terminate()
    p.join()
    assert (_queue_length(call_counts) == NUM_PARALLEL_CLIENTS * 2
            )  # request should be decompressed at start and end only
    for cp in client_processes:
        assert cp.exitcode == 0
Example #18
def test_no_matches_rest(query_dict):
    port = helper.random_port()
    with Flow(
            protocol='http',
            port=port,
            including_default_value_fields=True,
    ).add(uses=MockExecutor):
        # temporarily adding sleep
        time.sleep(0.5)
        query = json.dumps(query_dict).encode('utf-8')
        req = request.Request(
            f'http://localhost:{port}/search',
            data=query,
            headers={'content-type': 'application/json'},
        )
        resp = request.urlopen(req).read().decode('utf8')
        doc = json.loads(resp)['data'][0]

    assert len(Document.from_dict(doc).matches) == 0
    assert Document.from_dict(doc).tags['tag'] == 'test'
Example #19
def test_grpc_gateway_runtime_handle_messages_linear(linear_graph_dict,
                                                     monkeypatch, protocol):
    monkeypatch.setattr(
        networking.GrpcConnectionPool,
        'send_requests_once',
        DummyMockConnectionPool.send_requests_once,
    )
    port_in = random_port()

    def client_validate(client_id: int):
        responses = client_send(client_id, port_in, protocol)
        assert len(responses) > 0
        assert len(responses[0].docs) == 1
        assert (
            responses[0].docs[0].text ==
            f'client{client_id}-Request-client{client_id}-deployment0-client{client_id}-deployment1-client{client_id}-deployment2-client{client_id}-deployment3'
        )

    p = multiprocessing.Process(
        target=create_runtime,
        kwargs={
            'protocol': protocol,
            'port_in': port_in,
            'graph_dict': linear_graph_dict,
        },
    )
    p.start()
    time.sleep(1.0)
    client_processes = []
    for i in range(NUM_PARALLEL_CLIENTS):
        cp = multiprocessing.Process(target=client_validate,
                                     kwargs={'client_id': i})
        cp.start()
        client_processes.append(cp)

    for cp in client_processes:
        cp.join()
    p.terminate()
    p.join()
    for cp in client_processes:
        assert cp.exitcode == 0
Example #20
def mixin_client_gateway_parser(parser):
    """Add the options for the client connecting to the Gateway
    :param parser: the parser
    """
    gp = add_arg_group(parser, title='ClientGateway')
    _add_host(gp)
    _add_proxy(gp)

    gp.add_argument(
        '--port',
        type=int,
        default=helper.random_port(),
        help='The port of the Gateway, which the client should connect to.',
    )

    gp.add_argument(
        '--https',
        action='store_true',
        default=False,
        help='If set, connect to gateway using https',
    )
def test_flow_with_external_deployment_shards(
    external_deployment_shards_1,
    external_deployment_shards_2,
    external_deployment_shards_1_args,
    external_deployment_shards_2_args,
    input_docs,
    num_shards,
):
    with external_deployment_shards_1, external_deployment_shards_2:
        external_args_1 = vars(external_deployment_shards_1_args)
        external_args_2 = vars(external_deployment_shards_2_args)
        del external_args_1['name']
        del external_args_1['external']
        del external_args_1['deployment_role']
        del external_args_2['name']
        del external_args_2['external']
        del external_args_2['deployment_role']
        flow = (
            Flow()
            .add(name='executor1')
            .add(
                **external_args_1,
                name='external_fake_1',
                external=True,
                needs=['executor1'],
            )
            .add(
                **external_args_2,
                name='external_fake_2',
                external=True,
                needs=['executor1'],
            )
            .needs(needs=['external_fake_1', 'external_fake_2'], port=random_port())
        )

        with flow:
            resp = flow.index(inputs=input_docs)

        # Reducing applied on shards and needs, expect 50 docs
        validate_response(resp, 50)
    def process_wrapper():
        monkeypatch.setattr(
            networking.GrpcConnectionPool,
            'send_requests_once',
            DummyMockConnectionPool.send_requests_once,
        )
        port_in = random_port()

        with GRPCGatewayRuntime(set_gateway_parser().parse_args([
                '--port-expose',
                f'{port_in}',
                '--graph-description',
                f'{json.dumps(complete_graph_dict)}',
                '--deployments-addresses',
                '{}',
        ])) as runtime:

            async def _test():
                responses = []
                req = request_generator(
                    '/', DocumentArray([Document(text='client0-Request')]))
                async for resp in runtime.streamer.Call(request_iterator=req):
                    responses.append(resp)
                return responses

            responses = asyncio.run(_test())
        assert len(responses) > 0
        assert len(responses[0].docs) == 1
        deployment2_path = (
            f'client0-Request-client0-deployment0-client0-deployment2-client0-deployment3-client0-merger-client0-deployment_last'
            == responses[0].docs[0].text)
        deployment4_path = (
            f'client0-Request-client0-deployment4-client0-deployment5-client0-merger-client0-deployment_last'
            == responses[0].docs[0].text)
        assert (
            f'client0-Request-client0-deployment0-client0-deployment1-client0-merger-client0-deployment_last'
            == responses[0].docs[0].text or deployment2_path
            or deployment4_path)
    def process_wrapper():
        port_in = random_port()

        with GRPCGatewayRuntime(set_gateway_parser().parse_args([
                '--port-expose',
                f'{port_in}',
                '--graph-description',
                f'{json.dumps({})}',
                '--deployments-addresses',
                '{}',
        ])) as runtime:

            async def _test():
                responses = []
                req = request_generator(
                    '/', DocumentArray([Document(text='client0-Request')]))
                async for resp in runtime.streamer.Call(request_iterator=req):
                    responses.append(resp)
                return responses

            responses = asyncio.run(_test())
        assert len(responses) > 0
        assert len(responses[0].docs) == 1
        assert responses[0].docs[0].text == f'client0-Request'
def test_flow_with_external_pod_parallel(
    external_pod_parallel_1,
    external_pod_parallel_2,
    external_pod_parallel_1_args,
    external_pod_parallel_2_args,
    input_docs,
    num_replicas,
    num_parallel,
):
    with external_pod_parallel_1, external_pod_parallel_2:
        external_args_1 = vars(external_pod_parallel_1_args)
        external_args_2 = vars(external_pod_parallel_2_args)
        del external_args_1['name']
        del external_args_1['external']
        del external_args_1['pod_role']
        del external_args_1['dynamic_routing']
        del external_args_2['name']
        del external_args_2['external']
        del external_args_2['pod_role']
        del external_args_2['dynamic_routing']
        flow = (Flow().add(name='pod1').add(
            **external_args_1,
            name='external_fake_1',
            external=True,
            needs=['pod1'],
        ).add(
            **external_args_2,
            name='external_fake_2',
            external=True,
            needs=['pod1'],
        ).join(needs=['external_fake_1', 'external_fake_2'],
               port_in=random_port()))

        with flow:
            resp = flow.index(inputs=input_docs, return_results=True)
        validate_response(resp[0], 50 * num_parallel * 2)
Example #25
@pytest.mark.parametrize(
    'inputs',
    [iter([b'1234', b'45467']),
     iter([DocumentProto(), DocumentProto()])])
def test_check_input_success(inputs):
    Client.check_input(inputs)


@pytest.mark.parametrize(
    'inputs', [iter([list(), list(), [12, 2, 3]]),
               iter([set(), set()])])
def test_check_input_fail(inputs):
    with pytest.raises(BadClientInput):
        Client.check_input(inputs)


@pytest.mark.parametrize('port_expose, route, status_code',
                         [(random_port(), '/status', 200),
                          (random_port(), '/api/ass', 405)])
def test_gateway_ready(port_expose, route, status_code):
    p = set_gateway_parser().parse_args(
        ['--port-expose',
         str(port_expose), '--runtime-cls', 'RESTRuntime'])
    with Pea(p):
        time.sleep(0.5)
        a = requests.get(f'http://0.0.0.0:{p.port_expose}{route}')
        assert a.status_code == status_code


def test_gateway_index(flow_with_rest_api_enabled, test_img_1, test_img_2):
    with flow_with_rest_api_enabled:
        time.sleep(0.5)
        r = requests.post(
            f'http://0.0.0.0:{flow_with_rest_api_enabled.port_expose}/api/index',
            json={'data': [test_img_1, test_img_2]},
        )
        assert r.status_code == 200
Example #26
def test_helloworld_py_chatbot(tmpdir):
    from jina.helloworld.chatbot import hello_world
    hello_world(set_hw_chatbot_parser().parse_args(['--workdir', str(tmpdir),
                                                    '--unblock-query-flow',
                                                    '--port-expose', str(random_port())]))
Example #27
     iter([DocumentProto(), DocumentProto()])])
def test_check_input_success(input_fn):
    PyClient.check_input(input_fn)


@pytest.mark.parametrize(
    'input_fn',
    [iter([b'1234', '45467', [12, 2, 3]]),
     iter([DocumentProto(), None])])
def test_check_input_fail(input_fn):
    with pytest.raises(TypeError):
        PyClient.check_input(input_fn)


@pytest.mark.parametrize('port_expose, route, status_code',
                         [(random_port(), '/ready', 200),
                          (random_port(), '/api/ass', 405)])
def test_gateway_ready(port_expose, route, status_code):
    p = set_gateway_parser().parse_args(['--port-expose', str(port_expose)])
    with RESTGatewayPea(p):
        a = requests.get(f'http://0.0.0.0:{p.port_expose}{route}')
        assert a.status_code == status_code


def test_gateway_index(flow_with_rest_api_enabled, test_img_1, test_img_2):
    with flow_with_rest_api_enabled:
        r = requests.post(
            f'http://0.0.0.0:{flow_with_rest_api_enabled.port_expose}/api/index',
            json={'data': [test_img_1, test_img_2]},
        )
        assert r.status_code == 200
Example #28
def test_random_port(config):
    reset_ports()
    assert os.environ['JINA_RANDOM_PORT_MIN']
    port = random_port()
    assert 49153 <= port <= 65535
Example #29
def test_multiple_clients(prefetch, protocol, monkeypatch,
                          simple_graph_dict_indexer):
    GOOD_CLIENTS = 5
    GOOD_CLIENT_NUM_DOCS = 20
    MALICIOUS_CLIENT_NUM_DOCS = 50

    def get_document(i):
        return Document(
            id=f'{multiprocessing.current_process().name}_{i}',
            text=str(bytes(bytearray(os.urandom(512 * 4)))),
        )

    async def good_client_gen():
        for i in range(GOOD_CLIENT_NUM_DOCS):
            yield get_document(i)
            await asyncio.sleep(0.1)

    async def malicious_client_gen():
        for i in range(1000, 1000 + MALICIOUS_CLIENT_NUM_DOCS):
            yield get_document(i)

    def client(gen, port, protocol):
        Client(protocol=protocol, port=port).post(on='/index',
                                                  inputs=gen,
                                                  request_size=1)

    monkeypatch.setattr(
        networking.GrpcConnectionPool,
        'send_requests_once',
        DummyMockConnectionPool.send_requests_once,
    )
    port_in = random_port()

    pool = []
    runtime_process = multiprocessing.Process(
        target=create_runtime,
        kwargs={
            'protocol': protocol,
            'port_in': port_in,
            'graph_dict': simple_graph_dict_indexer,
            'prefetch': prefetch,
        },
    )
    runtime_process.start()
    time.sleep(1.0)
    # We have 5 good clients connecting to the same gateway. They have controlled requests.
    # Each client sends `GOOD_CLIENT_NUM_DOCS` (20) requests and sleeps after each request.
    for i in range(GOOD_CLIENTS):
        cp = multiprocessing.Process(
            target=partial(client, good_client_gen, port_in, protocol),
            name=f'goodguy_{i}',
        )
        cp.start()
        pool.append(cp)

    # and 1 malicious client, sending lot of requests (trying to block others)
    cp = multiprocessing.Process(
        target=partial(client, malicious_client_gen, port_in, protocol),
        name='badguy',
    )
    cp.start()
    pool.append(cp)

    for p in pool:
        p.join()

    order_of_ids = list(
        Client(protocol=protocol,
               port=port_in).post(on='/status',
                                  inputs=[Document()],
                                  return_results=True)[0].docs[0].tags['ids'])
    # There must be a total of 150 docs indexed.

    runtime_process.terminate()
    runtime_process.join()
    assert (len(order_of_ids) == GOOD_CLIENTS * GOOD_CLIENT_NUM_DOCS +
            MALICIOUS_CLIENT_NUM_DOCS)
    """
    If prefetch is set, each Client is allowed (max) 5 requests at a time.
    Since requests are controlled, `badguy` has to do the last 20 requests.

    If prefetch is disabled, clients can freeflow requests. No client is blocked.
    Hence last 20 requests go from `goodguy`.
    (Ideally last 30 requests should be validated, to avoid flaky CI, we test last 20)

    When there are no rules, badguy wins! With rule, you find balance in the world.
    """
    if protocol == 'http':
        # There's no prefetch for http.
        assert set(map(lambda x: x.split('_')[0],
                       order_of_ids[-20:])) == {'goodguy'}
    elif prefetch == 5:
        assert set(map(lambda x: x.split('_')[0],
                       order_of_ids[-20:])) == {'badguy'}
    elif prefetch == 0:
        assert set(map(lambda x: x.split('_')[0],
                       order_of_ids[-20:])) == {'goodguy'}
Example #30
    async def add(
        self,
        id: DaemonID,
        workspace_id: DaemonID,
        params: 'BaseModel',
        ports: Dict,
        envs: Dict[str, str] = {},
        **kwargs,
    ) -> DaemonID:
        """Add a container to the store

        :param id: id of the container
        :param workspace_id: workspace id where the container lives
        :param params: pydantic model representing the args for the container
        :param ports: ports to be mapped to local
        :param envs: dict of env vars to be passed
        :param kwargs: keyword args
        :raises KeyError: if workspace_id doesn't exist in the store
        :raises PartialDaemonConnectionException: if jinad cannot connect to partial
        :return: id of the container
        """
        try:
            from . import workspace_store

            if workspace_id not in workspace_store:
                raise KeyError(f'{workspace_id} not found in workspace store')

            minid_port = random_port()
            ports.update({f'{minid_port}/tcp': minid_port})
            uri = self._uri(minid_port)
            command = self._command(minid_port, workspace_id)
            params = params.dict(exclude={'log_config'})

            self._logger.debug(
                'creating container with following arguments \n' + '\n'.join([
                    '{:15s} -> {:15s}'.format('id', id),
                    '{:15s} -> {:15s}'.format('workspace', workspace_id),
                    '{:15s} -> {:15s}'.format('ports', str(ports)),
                    '{:15s} -> {:15s}'.format('command', command),
                ]))

            container, network, ports = Dockerizer.run(
                workspace_id=workspace_id,
                container_id=id,
                command=command,
                ports=ports,
                envs=envs,
            )
            if not await self.ready(uri):
                raise PartialDaemonConnectionException(
                    f'{id.type.title()} creation failed, couldn\'t reach the container at {uri} after 10secs'
                )
            object = await self._add(uri=uri, params=params, **kwargs)
        except Exception as e:
            self._logger.error(f'{self._kind} creation failed as {e}')
            container_logs = Dockerizer.logs(container.id)
            if container_logs and isinstance(
                    e,
                (PartialDaemon400Exception, PartialDaemonConnectionException)):
                self._logger.debug(
                    f'error logs from partial daemon: \n {container_logs}')
                if e.message and isinstance(e.message, list):
                    e.message += container_logs.split('\n')
                elif e.message and isinstance(e.message, str):
                    e.message += container_logs
            if id in Dockerizer.containers:
                self._logger.info(
                    f'removing container {id_cleaner(container.id)}')
                Dockerizer.rm_container(container.id)
            raise
        else:
            self[id] = ContainerItem(
                metadata=ContainerMetadata(
                    container_id=id_cleaner(container.id),
                    container_name=container.name,
                    image_id=id_cleaner(container.image.id),
                    network=network,
                    ports=ports,
                    uri=uri,
                ),
                arguments=ContainerArguments(
                    command=command,
                    object=object,
                ),
                workspace_id=workspace_id,
            )
            self._logger.success(
                f'{colored(id, "green")} is added to workspace {colored(workspace_id, "green")}'
            )
            workspace_store[workspace_id].metadata.managed_objects.add(id)
            return id