def run_sagemaker_test_in_executor(image, num_of_instances, instance_type):
    """
    Run the SageMaker pytest suite for a single image inside a fresh virtualenv.

    Expected to be invoked from a worker thread.

    :param image: <str> ECR url of the image under test
    :param num_of_instances: <int> number of instances the image test requires
    :param instance_type: <str> type of SageMaker instance the test needs
    :return: <bool> True when the run completed, False on any failure
    """
    import log_return

    LOGGER.info("Started running SageMaker test.....")
    pytest_command, path, tag, job_type = sm_utils.generate_sagemaker_pytest_cmd(image, "sagemaker")

    # Mark the requested instances as in use first; anything that fails after
    # that point is caught so the caller can react to the False return value.
    try:
        log_return.update_pool("running", instance_type, num_of_instances, job_type)
        context = Context()
        with context.cd(path):
            context.run(f"python3 -m virtualenv {tag}")
            with context.prefix(f"source {tag}/bin/activate"):
                # warn=True: a failed dependency install is tolerated and
                # pytest is still attempted
                context.run("pip install -r requirements.txt", warn=True)
                context.run(pytest_command)
    except Exception as e:
        LOGGER.error(e)
        return False
    return True
def run_sagemaker_remote_tests(images):
    """
    Set up multiprocessing/threading for the SageMaker remote tests.

    Dispatches on two environment flags:
      * EXECUTOR_MODE — run one image's test in-process (image taken from the
        DLC_IMAGE env var) and report the result back to the resource pool.
      * USE_SCHEDULER — submit one scheduler request per image on a thread pool.
    Otherwise the images are tested in parallel via a process pool.

    NOTE(review): a later definition with the same name (taking an extra
    ``pytest_cache_params`` argument) appears in this file and shadows this one
    at import time — confirm whether this version is still needed.

    :param images: <list> List of all images to be used in SageMaker tests
    """
    use_scheduler = os.getenv("USE_SCHEDULER", "False").lower() == "true"
    executor_mode = os.getenv("EXECUTOR_MODE", "False").lower() == "true"

    if executor_mode:
        LOGGER.info("entered executor mode.")
        import log_return

        num_of_instances = os.getenv("NUM_INSTANCES")
        image = os.getenv("DLC_IMAGE")
        job_type = "training" if "training" in image else "inference"
        instance_type = sm_utils.assign_sagemaker_remote_job_instance_type(image)
        test_succeeded = run_sagemaker_test_in_executor(image, num_of_instances, instance_type)

        tag = image.split("/")[-1].split(":")[-1]
        test_report = os.path.join(os.getcwd(), "test", f"{tag}.xml")
        # update in-progress pool, send the xml reports
        if test_succeeded:
            log_return.update_pool("completed", instance_type, num_of_instances, job_type, test_report)
        else:
            log_return.update_pool("runtimeError", instance_type, num_of_instances, job_type, test_report)
        return

    if use_scheduler:
        LOGGER.info("entered scheduler mode.")
        import concurrent.futures
        from job_requester import JobRequester

        job_requester = JobRequester()
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(images)) as executor:
            futures = [executor.submit(send_scheduler_requests, job_requester, image) for image in images]
            for future in futures:
                # one thread's failure is logged, not propagated, so the
                # remaining futures are still drained
                try:
                    future.result()
                except Exception as e:
                    LOGGER.error(f"An error occurred in one of the threads: {e}")
        return

    if not images:
        return
    with Pool(len(images)) as p:
        p.map(sm_utils.execute_sagemaker_remote_tests, images)
def run_sagemaker_remote_tests(images, pytest_cache_params):
    """
    Set up multiprocessing/threading for the SageMaker remote tests.

    Dispatches on two environment flags:
      * EXECUTOR_MODE — run the first image's test in-process and report the
        result back to the resource pool.
      * USE_SCHEDULER — submit one scheduler request per image on a thread pool.
    Otherwise the images are tested in parallel via a process pool, and the
    collected pytest cache is uploaded to S3 even when a worker fails.

    :param images: <list> List of all images to be used in SageMaker tests
    :param pytest_cache_params: <dict> keyword arguments forwarded to the
        pytest-cache S3 upload helper
    """
    # Guard against an empty image list up front: images[0] below would raise
    # IndexError and ThreadPoolExecutor(max_workers=0) would raise ValueError.
    if not images:
        return

    use_scheduler = os.getenv("USE_SCHEDULER", "False").lower() == "true"
    executor_mode = os.getenv("EXECUTOR_MODE", "False").lower() == "true"

    if executor_mode:
        LOGGER.info("entered executor mode.")
        import log_return

        num_of_instances = os.getenv("NUM_INSTANCES")
        image = images[0]
        job_type = "training" if "training" in image else "inference"
        instance_type = sm_utils.assign_sagemaker_remote_job_instance_type(image)
        test_succeeded = run_sagemaker_test_in_executor(image, num_of_instances, instance_type)

        tag = image.split("/")[-1].split(":")[-1]
        test_report = os.path.join(os.getcwd(), "test", f"{tag}.xml")
        # update in-progress pool, send the xml reports
        if test_succeeded:
            log_return.update_pool("completed", instance_type, num_of_instances, job_type, test_report)
        else:
            log_return.update_pool("runtimeError", instance_type, num_of_instances, job_type, test_report)
        return

    if use_scheduler:
        LOGGER.info("entered scheduler mode.")
        import concurrent.futures
        from job_requester import JobRequester

        job_requester = JobRequester()
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(images)) as executor:
            futures = [executor.submit(send_scheduler_requests, job_requester, image) for image in images]
            for future in futures:
                # Log (rather than propagate) a single thread's failure so the
                # remaining futures are still drained — consistent with the
                # other scheduler-mode implementation in this file.
                try:
                    future.result()
                except Exception as e:
                    LOGGER.error(f"An error occurred in one of the threads: {e}")
        return

    pool_number = len(images)
    # Using Manager().dict() since it's a process-safe dictionary
    global_pytest_cache = Manager().dict()
    try:
        with Pool(pool_number) as p:
            p.starmap(
                sm_utils.execute_sagemaker_remote_tests,
                [[i, images[i], global_pytest_cache, pytest_cache_params] for i in range(pool_number)],
            )
    finally:
        # upload whatever cache entries were gathered, even on worker failure
        pytest_cache_util.convert_cache_json_and_upload_to_s3(global_pytest_cache, **pytest_cache_params)
def test_requester():
    """
    Tests the send_request and receive_logs functions of the Job Requester package.
    How tests are executed:
    - create one Job Requester object, and multiple threads. Perform send_request with the
    Job Requester object in each of these threads.
    - send messages to the SQS queue that the Job Requester object created, to imitate the
    response logs received back from the Job Executor.
    - In each of the threads, perform receive_logs to receive the log correspond to the
    send_request earlier.
    """
    threads = 10
    request_object = JobRequester()

    # one identical (image, build_context, instance_count) request per thread
    input_list = [(TEST_IMAGE, "PR", 3) for _ in range(threads)]

    # fan the requests out across the thread pool and collect the identifiers
    identifiers_list = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
        futures = [executor.submit(request_object.send_request, *args) for args in input_list]
        print("Created tickets......")
        for future in futures:
            res = future.result()
            print(res)
            identifiers_list.append(res)
    print("\n")

    # create sample xml report files
    image_tag = TEST_IMAGE.split(":")[-1]
    report_path = os.path.join(os.getcwd(), f"{image_tag}.xml")
    with open(report_path, "w") as report:
        report.write(SAMPLE_XML_MESSAGE)

    # imitate the Job Executor posting a completed result for every ticket
    os.environ["CODEBUILD_BUILD_ARN"] = SAMPLE_CB_ARN
    for identifier in identifiers_list:
        os.environ["TICKET_KEY"] = f"folder/{identifier.ticket_name}"
        log_return.update_pool("completed", identifier.instance_type, 3, identifier.job_type, report_path)

    # receiving logs
    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
        logs = [executor.submit(request_object.receive_logs, identifier) for identifier in identifiers_list]
        LOGGER.info("Receiving logs...")
        for log in logs:
            assert "XML_REPORT" in log.result(), f"XML Report not found as part of the returned log message."

    # clean up test artifacts
    S3 = boto3.client("s3")
    ticket_names = [item.ticket_name for item in identifiers_list]
    for name in ticket_names:
        S3.delete_object(Bucket=request_object.s3_ticket_bucket, Key=name)
    LOGGER.info("Tests passed.")