예제 #1
0
def test_restart_on_sudden_instance_termination(training_finished,
                                                launch_train, spot_terminated,
                                                caplog):
    # Minimal stand-in for a boto EC2 instance: the worker only reads `id`.
    class FakeInstance:
        id = 1

    launch_train.return_value = 0

    # Build a worker from the template AWS event configuration.
    aws_config = read_config(ramp_aws_config_template())['worker']
    worker = AWSWorker(aws_config, submission='starting_kit_local')
    worker.config = aws_config
    worker.submission = 'dummy submissions'
    worker.instance = FakeInstance

    # Healthy spot instance, training still in progress: the worker
    # reports 'running' and logs nothing.
    spot_terminated.return_value = False
    training_finished.return_value = False
    worker.launch_submission()
    assert worker.status == 'running'
    assert caplog.text == ''

    # Make the "is training finished?" check fail with CalledProcessError:
    # reading the status property must switch the worker to 'retry' and
    # log that the submission goes back into the queue.
    training_finished.side_effect = subprocess.CalledProcessError(255, 'test')
    assert worker.status == 'retry'
    assert 'Unable to connect to the instance' in caplog.text
    assert 'Adding the submission back to the queue' in caplog.text
예제 #2
0
def test_rsync_upload_fails(test_rsync):
    # Every rsync invocation fails with a CalledProcessError.
    test_rsync.side_effect = subprocess.CalledProcessError(255, 'test')
    aws_config = read_config(ramp_aws_config_template())['worker']
    # upload_submission must swallow the error and report it through a
    # non-zero return value instead of raising.
    exit_code = upload_submission(aws_config, 0, 'test_submission', 'temp')
    assert exit_code == 1
예제 #3
0
def session_toy_aws(database_connection):
    """Yield a database session over a freshly deployed toy database.

    On teardown the database schema is dropped and the deployment
    directory removed, whether or not the setup succeeded.
    """
    database_config = read_config(database_config_template())
    ramp_config_aws = ramp_aws_config_template()
    # Initialize before the try block: if create_toy_db raises, the
    # finally clause would otherwise hit an UnboundLocalError on
    # `deployment_dir`, masking the original exception.
    deployment_dir = None
    try:
        deployment_dir = create_toy_db(database_config, ramp_config_aws)
        with session_scope(database_config['sqlalchemy']) as session:
            yield session
    finally:
        db, _ = setup_db(database_config['sqlalchemy'])
        Model.metadata.drop_all(db)
        if deployment_dir is not None:
            shutil.rmtree(deployment_dir, ignore_errors=True)
예제 #4
0
def test_launch_ec2_instances(boto_session_cls, use_spot_instance):
    """Check 'use_spot_instance' config with None, True and False."""
    # Fake boto session whose image query always finds one AMI.
    fake_client = boto_session_cls.return_value.client.return_value
    fake_client.describe_images.return_value = {
        "Images": [{"ImageId": 1, "CreationDate": 123}]}

    aws_config = read_config(ramp_aws_config_template())['worker']
    aws_config['use_spot_instance'] = use_spot_instance
    # Must not raise for any of the parametrized values.
    launch_ec2_instances(aws_config)
예제 #5
0
def test_launch_ec2_instances_put_back_into_queue(test_launch_ec2_instances,
                                                  caplog):
    """The worker is flagged 'retry' and the re-queue message is logged
    when the AWS API hands back no instances and a retry status."""
    # The patched launcher reports no instance together with 'retry'.
    test_launch_ec2_instances.return_value = None, 'retry'

    aws_config = read_config(ramp_aws_config_template())['worker']
    worker = AWSWorker(aws_config, submission='starting_kit_local')
    worker.config = aws_config

    # setup() should not raise: it marks the worker for a later retry
    # and logs that the submission is put back into the queue.
    worker.setup()
    assert worker.status == 'retry'
    assert 'Adding it back to the queue and will try again' in caplog.text
예제 #6
0
def test_rsync_download_log(test_rsync, caplog):
    rsync_error = subprocess.CalledProcessError(255, 'test')
    aws_config = read_config(ramp_aws_config_template())['worker']

    # Two failures followed by a success: download_log retries and
    # finally returns the log content.
    test_rsync.side_effect = [rsync_error, rsync_error, 'test_log']
    log = download_log(aws_config, 0, 'test_submission')
    assert 'Trying to download the log' in caplog.text
    assert log == 'test_log'

    # Three failures in a row exhaust the retries: the last
    # CalledProcessError propagates to the caller.
    test_rsync.side_effect = [rsync_error] * 3
    with pytest.raises(subprocess.CalledProcessError):
        download_log(aws_config, 0, 'test_submission')
    assert 'Trying to download the log' in caplog.text
예제 #7
0
def test_aws_worker_upload_error(test_launch_ec2_instances, test_rsync,
                                 caplog):
    # Minimal stand-in for a boto EC2 instance.
    class FakeInstance:
        id = 1

    # Launching succeeds, but every rsync call fails.
    test_launch_ec2_instances.return_value = (FakeInstance(),), 0
    test_rsync.side_effect = subprocess.CalledProcessError(255, 'test')

    aws_config = read_config(ramp_aws_config_template())['worker']
    worker = AWSWorker(aws_config, submission='starting_kit_local')
    worker.config = aws_config

    # setup() must catch the CalledProcessError internally, flag the
    # worker as errored, and log the connection problem.
    worker.setup()
    assert worker.status == 'error'
    assert 'Unable to connect during log download' in caplog.text
예제 #8
0
def test_aws_worker_launch_train_error(launch_train, caplog):
    # Minimal stand-in for a boto EC2 instance.
    class FakeInstance:
        id = 1

    # Starting the training command on the instance fails.
    launch_train.side_effect = subprocess.CalledProcessError(255, 'test')

    aws_config = read_config(ramp_aws_config_template())['worker']
    worker = AWSWorker(aws_config, submission='starting_kit_local')
    worker.config = aws_config
    worker.submission = 'dummy submissions'
    worker.instance = FakeInstance

    # launch_submission() swallows the CalledProcessError, logs it, and
    # signals the failure through both status and return code.
    exit_code = worker.launch_submission()
    assert 'test' in caplog.text
    assert 'Cannot start training of submission' in caplog.text
    assert worker.status == 'error'
    assert exit_code == 1
예제 #9
0
def test_creating_instances(boto_session_cls, caplog,
                            aws_msg_type, result_none, log_msg):
    """Launching more instances than the AWS spot-instance limit."""
    # info: caplog is the pytest fixture collecting logging output.
    # Fake boto session: exactly one AMI is available.
    client = boto_session_cls.return_value.client.return_value
    client.describe_images.return_value = {
        "Images": [{"ImageId": 1, "CreationDate": 123}]}

    aws_config = read_config(ramp_aws_config_template())['worker']
    aws_config['use_spot_instance'] = True

    # The three kinds of answers AWS may give to a spot request.
    limit_error = botocore.exceptions.ClientError(
        {
            "ClientError": {
                "Code": "Max spot instance count exceeded"
            }
        },
        "MaxSpotInstanceCountExceeded")
    unhandled_error = botocore.exceptions.ParamValidationError(
        report='this is temporary message')
    good_response = {'SpotInstanceRequests':
                     [{'SpotInstanceRequestId': ['temp']}]
                     }

    # Script what AWS answers on successive spot requests.
    if aws_msg_type == 'max_spot':
        scripted_responses = [limit_error] * 4
    elif aws_msg_type == 'unhandled':
        scripted_responses = [unhandled_error] * 2
    elif aws_msg_type == 'correct':
        scripted_responses = [limit_error, good_response]

    client.request_spot_instances.side_effect = scripted_responses
    instance, status = launch_ec2_instances(aws_config)
    assert (instance is None) == result_none
    assert log_msg in caplog.text
예제 #10
0
def test_aws_worker_download_log_error(superclass, test_rsync,
                                       caplog):
    # Fake instance exposing only the id attribute the worker reads.
    class FakeInstance:
        id = 'test'

    # Every rsync (log download) attempt fails.
    test_rsync.side_effect = subprocess.CalledProcessError(255, 'test')

    # Pretend the base-class check succeeds.
    superclass.return_value = True
    aws_config = read_config(ramp_aws_config_template())['worker']

    worker = AWSWorker(aws_config, submission='starting_kit_local')
    worker.config = aws_config
    worker.status = 'finished'
    worker.instance = FakeInstance

    # collect_results() retries the download, gives up, and reports the
    # failure via its return values, the log, and the worker status.
    exit_status, error_msg = worker.collect_results()
    assert 'Error occurred when downloading the logs' in caplog.text
    assert 'Trying to download the log once again' in caplog.text
    assert exit_status == 2
    assert 'test' in error_msg
    assert worker.status == 'error'
예제 #11
0
def test_dispatcher_aws_not_launching(session_toy_aws, caplog):
    """With the template credentials the instance cannot authenticate,
    so the worker tears down after the failed attempt and no submission
    leaves the 'new' state."""
    config = read_config(database_config_template())
    event_config = read_config(ramp_aws_config_template())

    dispatcher = Dispatcher(config=config,
                            event_config=event_config,
                            worker=AWSWorker,
                            n_workers=10,
                            hunger_policy='exit')
    dispatcher.fetch_from_db(session_toy_aws)
    submissions_before = get_submissions(session_toy_aws, 'iris_aws_test',
                                         'new')

    dispatcher.launch_workers(session_toy_aws)
    # Authentication failed and training never started.
    assert 'AuthFailure' in caplog.text
    assert 'training' not in caplog.text
    assert dispatcher._processing_worker_queue.qsize() == 0
    submissions_after = get_submissions(session_toy_aws, 'iris_aws_test',
                                        'new')
    # Every submission is still in the 'new' state.
    assert len(submissions_before) == len(submissions_after)
예제 #12
0
def test_info_on_training_error(test_launch_ec2_instances, upload_submission,
                                launch_train, is_spot_terminated,
                                training_finished, training_successful,
                                get_log_content, check_instance_status,
                                download_log, session_toy_aws, caplog):
    """The Python error produced by a failing submission must be passed
    through to the dispatcher and stored on the submission record."""
    # Everything AWS-side is mocked as a successful instance setup and
    # submission upload.
    class FakeInstance:
        id = 1

    test_launch_ec2_instances.return_value = (FakeInstance(), ), 0
    upload_submission.return_value = 0
    launch_train.return_value = 0
    is_spot_terminated.return_value = 0
    training_finished.return_value = False
    download_log.return_value = 0

    config = read_config(database_config_template())
    event_config = read_config(ramp_aws_config_template())

    dispatcher = Dispatcher(config=config,
                            event_config=event_config,
                            worker=AWSWorker,
                            n_workers=10,
                            hunger_policy='exit')
    dispatcher.fetch_from_db(session_toy_aws)
    dispatcher.launch_workers(session_toy_aws)
    n_running = dispatcher._processing_worker_queue.qsize()
    submissions = get_submissions(session_toy_aws, 'iris_aws_test', 'training')
    ids = [sub[0] for sub in submissions]
    assert len(submissions) > 1
    assert n_running == len(ids)

    dispatcher.time_between_collection = 0
    training_successful.return_value = False

    # End the training with an error: the instance reports 'finished'
    # and the log contains the Python error message.
    training_finished.return_value = True
    training_error_msg = 'Python error here'
    get_log_content.return_value = training_error_msg
    check_instance_status.return_value = 'finished'

    dispatcher.collect_result(session_toy_aws)

    # The worker that was processing the submission has been torn down.
    assert dispatcher._processing_worker_queue.qsize() == 0

    # All submissions moved to 'training_error' ...
    submissions = get_submissions(session_toy_aws, 'iris_aws_test',
                                  'training_error')
    assert len(submissions) == len(ids)

    # ... and the error message reached the database.
    submission = get_submission_by_id(session_toy_aws, submissions[0][0])
    assert training_error_msg in submission.error_msg
예제 #13
0
def test_is_spot_terminated_with_CalledProcessError(test_run, caplog):
    # curl on the instance fails: the helper must log, not raise.
    test_run.side_effect = subprocess.CalledProcessError(28, 'test')
    aws_config = read_config(ramp_aws_config_template())['worker']
    is_spot_terminated(aws_config, 0)
    assert 'Unable to run curl' in caplog.text