Exemplo n.º 1
0
def test_get_dataset_hashes_generated_number(client_task_definition_path,
                                             tmpdir):
    """Check that ``get_dataset_hashes`` yields NUMBER_OF_DATASET_SEGMENTS entries.

    The client writes its output into pytest's temporary directory and loads
    the training request from the task-definition fixture before hashing.
    """
    client = Client()
    client._worker_output_directory = str(tmpdir)
    client.load_training_request_data(client_task_definition_path)

    dataset_hashes = client.get_dataset_hashes()
    assert len(dataset_hashes) == NUMBER_OF_DATASET_SEGMENTS
Exemplo n.º 2
0
def test_get_dataset_hashes_are_segment_files_generated(
        client_task_definition_path, tmpdir):
    """Check that every path returned by ``get_dataset_hashes`` is a real file.

    The client writes its output into pytest's temporary directory, loads the
    training request from the task-definition fixture, and the values of the
    returned mapping are expected to be paths of generated segment files.
    """
    client = Client()

    client._worker_output_directory = str(tmpdir)
    client.load_training_request_data(client_task_definition_path)
    hash_data = client.get_dataset_hashes()

    # Guard against a vacuous pass: with an empty mapping the loop below would
    # run zero times and the test would succeed without checking anything.
    assert hash_data, 'get_dataset_hashes() returned no segments'

    for segment_path in hash_data.values():
        # Report the offending path so a failure is actionable.
        assert os.path.isfile(segment_path), segment_path
Exemplo n.º 3
0
def test_send_dataset_hashes_message_format_testing(
        redisdb, client_task_definition_data, tmpdir):
    """Check the message ``send_dataset_hashes`` pushes to the cluster queue.

    The client is wired to the ``redisdb`` fixture and a fixed cluster address.
    After sending, one message is popped from that address's list, parsed as
    YAML, and checked to carry this client's id and one hash per dataset
    segment.
    """
    client = Client()
    client.conn = redisdb
    client._cluster_request_data = client_task_definition_data
    client._cluster_address = 'test_cluster_address'
    client._worker_output_directory = str(tmpdir)

    client.send_dataset_hashes()

    message = redisdb.lpop('test_cluster_address')
    # lpop returns None when the list is empty (redis-py); fail with a clear
    # message instead of an opaque error from the YAML parser.
    assert message is not None, 'no message was pushed to the cluster queue'

    # NOTE(review): UnsafeLoader can construct arbitrary Python objects. It is
    # tolerable here because the payload comes from the client under test, but
    # consider yaml.safe_load if the message contains only plain YAML types.
    message_data = yaml.load(message, yaml.UnsafeLoader)

    assert message_data['client_id'] == client._client_id
    assert len(message_data['hashes']) == NUMBER_OF_DATASET_SEGMENTS