# NOTE(review): dead code — this function is immediately redefined below with a
# second parameter (pytest_cache_params); that later definition shadows this one
# at import time, so this body is never callable. Consider deleting it.
def run_sagemaker_remote_tests(images):
    """
    Function to set up multiprocessing for SageMaker tests.

    Dispatches on two env flags:
      * EXECUTOR_MODE — run a single test for the image in DLC_IMAGE and report
        the result to the in-progress pool.
      * USE_SCHEDULER — submit one scheduler request per image on a thread pool.
      * neither — fan the images out over a process Pool.

    :param images: <list> List of all images to be used in SageMaker tests
    """
    use_scheduler = os.getenv("USE_SCHEDULER", "False").lower() == "true"
    executor_mode = os.getenv("EXECUTOR_MODE", "False").lower() == "true"
    if executor_mode:
        LOGGER.info("entered executor mode.")
        # imported lazily: only needed (and presumably only importable) in executor mode
        import log_return

        # NOTE(review): os.getenv returns a string (or None) — num_of_instances is
        # passed through un-converted, and DLC_IMAGE unset would make the
        # membership test below raise TypeError; TODO confirm both are always set.
        num_of_instances = os.getenv("NUM_INSTANCES")
        image = os.getenv("DLC_IMAGE")
        job_type = "training" if "training" in image else "inference"
        instance_type = sm_utils.assign_sagemaker_remote_job_instance_type(image)
        test_succeeded = run_sagemaker_test_in_executor(image, num_of_instances, instance_type)
        # report file is named after the image tag, e.g. repo/name:tag -> tag.xml
        tag = image.split("/")[-1].split(":")[-1]
        test_report = os.path.join(os.getcwd(), "test", f"{tag}.xml")
        # update in-progress pool, send the xml reports
        if test_succeeded:
            log_return.update_pool("completed", instance_type, num_of_instances, job_type, test_report)
        else:
            log_return.update_pool("runtimeError", instance_type, num_of_instances, job_type, test_report)
        return
    elif use_scheduler:
        LOGGER.info("entered scheduler mode.")
        import concurrent.futures
        from job_requester import JobRequester

        job_requester = JobRequester()
        # one thread per image; errors are logged but do not abort the other submissions
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(images)) as executor:
            futures = [executor.submit(send_scheduler_requests, job_requester, image) for image in images]
            for future in futures:
                try:
                    future.result()
                except Exception as e:
                    LOGGER.error(f"An error occurred in one of the threads: {e}")
    else:
        # plain multiprocessing fan-out; nothing to do for an empty image list
        if not images:
            return
        pool_number = len(images)
        with Pool(pool_number) as p:
            p.map(sm_utils.execute_sagemaker_remote_tests, images)
def run_sagemaker_remote_tests(images, pytest_cache_params):
    """
    Function to set up multiprocessing for SageMaker tests.

    Dispatches on two env flags:
      * EXECUTOR_MODE — run a single test for images[0] and report the result
        to the in-progress pool.
      * USE_SCHEDULER — submit one scheduler request per image on a thread pool.
      * neither — fan the images out over a process Pool and upload the shared
        pytest cache afterwards.

    :param images: <list> List of all images to be used in SageMaker tests
    :param pytest_cache_params: <dict> keyword args forwarded to
        pytest_cache_util.convert_cache_json_and_upload_to_s3 — presumably
        identifies the S3 cache destination; TODO confirm against that helper.
    """
    # Guard first: an empty image list would otherwise crash later —
    # ThreadPoolExecutor(max_workers=0) raises ValueError in scheduler mode and
    # images[0] raises IndexError in executor mode. (The original guarded only
    # the multiprocessing branch.)
    if not images:
        return
    use_scheduler = os.getenv("USE_SCHEDULER", "False").lower() == "true"
    executor_mode = os.getenv("EXECUTOR_MODE", "False").lower() == "true"

    if executor_mode:
        LOGGER.info("entered executor mode.")
        # imported lazily: only needed (and presumably only importable) in executor mode
        import log_return

        # NOTE(review): os.getenv returns a string (or None) — num_of_instances
        # is passed through un-converted; TODO confirm NUM_INSTANCES is always set.
        num_of_instances = os.getenv("NUM_INSTANCES")
        image = images[0]
        job_type = "training" if "training" in image else "inference"
        instance_type = sm_utils.assign_sagemaker_remote_job_instance_type(image)
        test_succeeded = run_sagemaker_test_in_executor(image, num_of_instances, instance_type)
        # report file is named after the image tag, e.g. repo/name:tag -> tag.xml
        tag = image.split("/")[-1].split(":")[-1]
        test_report = os.path.join(os.getcwd(), "test", f"{tag}.xml")
        # update in-progress pool, send the xml reports
        if test_succeeded:
            log_return.update_pool("completed", instance_type, num_of_instances, job_type, test_report)
        else:
            log_return.update_pool("runtimeError", instance_type, num_of_instances, job_type, test_report)
        return

    if use_scheduler:
        LOGGER.info("entered scheduler mode.")
        import concurrent.futures
        from job_requester import JobRequester

        job_requester = JobRequester()
        # one thread per image; result() re-raises any worker exception here
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(images)) as executor:
            futures = [executor.submit(send_scheduler_requests, job_requester, image) for image in images]
            for future in futures:
                future.result()
        return

    pool_number = len(images)
    # Manager().dict() is a process-safe proxied dict, shared with Pool workers
    # (a plain dict would not propagate writes back from child processes).
    global_pytest_cache = Manager().dict()
    try:
        with Pool(pool_number) as p:
            p.starmap(
                sm_utils.execute_sagemaker_remote_tests,
                [[i, images[i], global_pytest_cache, pytest_cache_params] for i in range(pool_number)],
            )
    finally:
        # Upload whatever cache entries were collected, even if a worker failed.
        pytest_cache_util.convert_cache_json_and_upload_to_s3(global_pytest_cache, **pytest_cache_params)
def main():
    """Exercise queuing, in-progress, and dead-letter ticket queries end to end."""
    requester = JobRequester()
    ticket_prefix = f"testing-0_{REQUEST_TICKET_TIME}"
    request_ticket_name = f"{ticket_prefix}.json"

    # identifier for the request ticket
    identifier = Message(
        SQS_RETURN_QUEUE_URL,
        BUCKET_NAME,
        request_ticket_name,
        TEST_ECR_URI,
        REQUEST_TICKET_TIME,
    )
    test_query_and_cancel_queuing_tickets(requester, request_ticket_name, identifier)

    # in-progress pool tickets are named {request ticket name}#{num of instances}-{status}.json
    test_query_in_progress_tickets(requester, f"{ticket_prefix}#1-running.json", identifier)

    # dead-letter pool tickets are named {request ticket name}-{failure reason}.json
    test_query_dead_letter_tickets(requester, f"{ticket_prefix}-timeout.json", identifier)

    LOGGER.info("Tests passed.")
def test_requester():
    """
    Tests the send_request and receive_logs functions of the Job Requester package.

    How tests are executed:
    - create one Job Requester object, and multiple threads. Perform send_request
      with the Job Requester object in each of these threads.
    - send messages to the SQS queue that the Job Requester object created, to
      imitate the response logs received back from the Job Executor.
    - In each of the threads, perform receive_logs to receive the log
      correspond to the send_request earlier.
    """
    thread_count = 10
    requester = JobRequester()
    # identical request payloads, one per thread
    request_args = [(TEST_IMAGE, "PR", 3) for _ in range(thread_count)]
    tickets = []

    # fan send_request out across the thread pool and collect the identifiers
    with concurrent.futures.ThreadPoolExecutor(max_workers=thread_count) as executor:
        pending = [
            executor.submit(requester.send_request, image, build_context, instances)
            for (image, build_context, instances) in request_args
        ]
        print("Created tickets......")
        for item in pending:
            ticket = item.result()
            print(ticket)
            tickets.append(ticket)
    print("\n")

    # write a sample xml report file named after the test image's tag
    image_tag = TEST_IMAGE.split(":")[-1]
    report_path = os.path.join(os.getcwd(), f"{image_tag}.xml")
    with open(report_path, "w") as report:
        report.write(SAMPLE_XML_MESSAGE)

    # imitate the executor completing each job and posting its log
    os.environ["CODEBUILD_BUILD_ARN"] = SAMPLE_CB_ARN
    for ticket in tickets:
        os.environ["TICKET_KEY"] = f"folder/{ticket.ticket_name}"
        log_return.update_pool("completed", ticket.instance_type, 3, ticket.job_type, report_path)

    # receive the logs concurrently and check each one carries the xml report
    with concurrent.futures.ThreadPoolExecutor(max_workers=thread_count) as executor:
        log_futures = [executor.submit(requester.receive_logs, ticket) for ticket in tickets]
        LOGGER.info("Receiving logs...")
        for log in log_futures:
            assert "XML_REPORT" in log.result(), f"XML Report not found as part of the returned log message."

    # clean up test artifacts
    s3_client = boto3.client("s3")
    for name in (ticket.ticket_name for ticket in tickets):
        s3_client.delete_object(Bucket=requester.s3_ticket_bucket, Key=name)
    LOGGER.info("Tests passed.")