def test_get_cluster_response_valid(redisdb):
    """A response already queued on the client's listen address is returned."""
    listener = Client()
    listener.conn = redisdb
    redisdb.lpush(listener._client_listen_address, 'test')

    response = listener.get_cluster_response()

    assert response == 'test'
def test_send_data_segments_to_client_response_message_order(
        mocker, tmpdir, redisdb, client_task_definition_data):
    """Segment messages reach the cluster in ``client.segment_hashes`` order.

    Three segment files are created and registered in ``_dataset_segments``.
    ``segment_hashes`` lists them in a different order, and the messages
    pushed to the cluster address must follow ``segment_hashes``.
    """
    segment_paths = {}
    for name in ('test0', 'test1', 'test2'):
        segment_file = tmpdir.join(name + '.segment')
        segment_file.write('content')
        segment_paths[name] = str(segment_file)

    client = Client()
    client.conn = redisdb
    client._cluster_request_data = client_task_definition_data
    client._cluster_address = str(uuid.uuid4())
    client._dataset_segments = segment_paths
    # Deliberately not the insertion order of _dataset_segments, so the test
    # distinguishes "follow segment_hashes" from "iterate the dict".
    client.segment_hashes = ['test1', 'test2', 'test0']

    client.send_data_segments_to_cluster()

    message_raw = redisdb.lrange(client._cluster_address, 0, -1)
    # lrange returns a (possibly empty) list, never None. Require at least one
    # message per hash; otherwise zip() below would truncate and the order
    # check could pass without checking anything.
    assert len(message_raw) >= len(client.segment_hashes)
    # UnsafeLoader is tolerable here: the payload was produced by the code
    # under test within this test, not taken from external input.
    messages = [yaml.load(raw, Loader=yaml.UnsafeLoader) for raw in message_raw]
    for key, message in zip(client.segment_hashes, messages):
        assert message['hash'] == key
def test_get_cluster_response_waiting(redisdb, monkeypatch):
    """get_cluster_response() keeps waiting until a response is queued.

    ``time.sleep`` (as seen by the client module) is replaced by a fake that
    counts calls. From the 11th call onward, each call pushes a response
    onto the client's listen address.
    """
    client = Client()
    client.conn = redisdb
    sleep_calls = 0

    def fake_sleep(_seconds):
        nonlocal sleep_calls
        sleep_calls += 1
        if sleep_calls > 10:
            redisdb.lpush(client._client_listen_address, 'test')

    # NOTE(review): this patches the attribute on the ``time`` object the
    # client module references (presumably the shared stdlib module), so the
    # fake applies process-wide until monkeypatch restores it.
    monkeypatch.setattr(
        pai.pouw.nodes.decentralized.client.time, 'sleep', fake_sleep)

    assert client.get_cluster_response() == 'test'
def test_send_dataset_hashes_message_format_testing(
        redisdb, client_task_definition_data, tmpdir):
    """The dataset-hashes message carries the client id and one hash per segment."""
    client = Client()
    client.conn = redisdb
    client._cluster_request_data = client_task_definition_data
    client._cluster_address = 'test_cluster_address'
    client._worker_output_directory = str(tmpdir)

    client.send_dataset_hashes()

    message = redisdb.lpop('test_cluster_address')
    # lpop returns None for an empty/missing list; fail clearly here rather
    # than with an unrelated error from yaml.load(None).
    assert message is not None, 'send_dataset_hashes() pushed no message'
    # UnsafeLoader is tolerable: the payload comes from the code under test.
    message_data = yaml.load(message, Loader=yaml.UnsafeLoader)
    assert message_data['client_id'] == client._client_id
    assert len(message_data['hashes']) == NUMBER_OF_DATASET_SEGMENTS
def test_send_data_segments_to_client_cleanup_procedure(
        mocker, tmpdir, redisdb, client_task_definition_data):
    """After sending, the local segment file no longer exists on disk."""
    segment_file = tmpdir.join('test.segment')
    segment_file.write('content')
    segment_path = str(segment_file)

    client = Client()
    client.conn = redisdb
    client._cluster_request_data = client_task_definition_data
    client._cluster_address = str(uuid.uuid4())
    client._dataset_segments = {'test': segment_path}
    client.segment_hashes = ['test']

    client.send_data_segments_to_cluster()

    assert os.path.isfile(segment_path) is False
def test_inform_client_of_task_integrated(client_task_definition_data, redisdb):
    """A committee candidate announces the task; the client picks it up.

    The node publishes its task id and response channel via
    inform_client_of_task_id(). The client, listening on the same address,
    must then see both via obtain_cluster_task_id().
    """
    client_address = 'test_client'

    node = CommitteeCandidate()
    node.conn = redisdb
    node.task_id = '123'
    # Shallow-copy so setting the listen address does not mutate the fixture
    # object itself (it would leak into other tests if the fixture is shared;
    # its scope is not visible here).
    node.task_data = dict(client_task_definition_data)
    node.task_data['client_listen_address'] = client_address
    node._client_response_listening_channel = 'test_cluster'
    node.inform_client_of_task_id()

    client = Client()
    client._client_id = client_task_definition_data['client_id']
    client.conn = redisdb
    client._client_listen_address = client_address
    client.obtain_cluster_task_id()

    assert client._task_id == '123'
    assert client._cluster_address == 'test_cluster'
def test_send_data_segments_to_client_response_message(
        mocker, tmpdir, redisdb, client_task_definition_data):
    """The first segment message has 'hash', 'bucket' and 'key' fields."""
    segment = tmpdir.join('test.segment')
    segment.write('content')

    client = Client()
    client.conn = redisdb
    client._cluster_request_data = client_task_definition_data
    client._cluster_address = str(uuid.uuid4())
    client._dataset_segments = {'test': str(segment)}
    client.segment_hashes = ['test']

    client.send_data_segments_to_cluster()

    message_raw = redisdb.lrange(client._cluster_address, 0, -1)
    # lrange returns a (possibly empty) list, never None, so test emptiness.
    assert message_raw, 'no messages were pushed to the cluster address'
    # UnsafeLoader is tolerable: the payload comes from the code under test.
    messages = [yaml.load(raw, Loader=yaml.UnsafeLoader) for raw in message_raw]
    message = messages[0]
    assert all(key in message for key in ('hash', 'bucket', 'key'))
def test_send_initial_training_request(client_task_definition_path, redisdb):
    """Sending the initial request enqueues exactly one task on the channel."""
    requester = Client()
    requester.conn = redisdb

    requester.send_initial_training_request(client_task_definition_path)

    queued = redisdb.llen(CLIENT_TASK_CHANNEL)
    assert queued == 1