def test_model_image(sagemaker_session):
    """A model created from a fitted LDA estimator uses the registry LDA image tagged ``1``."""
    # NOTE(review): ``test_model_image`` is defined more than once in this module; Python
    # keeps only the last definition, so confirm which variant is meant to be collected.
    estimator = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    train_uri = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri, num_records=1, feature_dim=FEATURE_DIM, channel='train')
    estimator.fit(records, MINI_BATCH_SZIE)

    created_model = estimator.create_model()
    assert created_model.image == registry(REGION, 'lda') + '/lda:1'
def test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session):
    """``_prepare_for_training`` rejects a mini-batch size of 0 with ValueError."""
    # NOTE(review): this test name is defined more than once in this module; only the
    # last definition is collected by pytest.
    estimator = LDA(base_job_name='lda', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    train_uri = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri, num_records=1, feature_dim=FEATURE_DIM, channel='train')

    zero_batch = 0
    with pytest.raises(ValueError):
        estimator._prepare_for_training(records, zero_batch)
def test_predictor_type(sagemaker_session):
    """Deploying the model of a fitted LDA estimator returns an ``LDAPredictor``."""
    estimator = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    train_uri = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri, num_records=1, feature_dim=FEATURE_DIM, channel='train')
    estimator.fit(records, MINI_BATCH_SZIE)

    endpoint_predictor = estimator.create_model().deploy(1, TRAIN_INSTANCE_TYPE)
    assert isinstance(endpoint_predictor, LDAPredictor)
def test_call_fit(base_fit, sagemaker_session):
    """``fit`` forwards exactly ``(records, mini_batch_size)`` positionally to ``base_fit``.

    ``base_fit`` is presumably a mock patched over the base estimator's ``fit``
    (it exposes ``assert_called_once`` / ``call_args``) — confirm against the fixture.
    """
    estimator = LDA(base_job_name='lda', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    train_uri = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri, num_records=1, feature_dim=FEATURE_DIM, channel='train')
    # Keep the call positional: the assertions below inspect positional args only.
    estimator.fit(records, MINI_BATCH_SZIE)

    base_fit.assert_called_once()
    positional = base_fit.call_args[0]
    assert len(positional) == 2
    forwarded_records, forwarded_batch = positional
    assert forwarded_records == records
    assert forwarded_batch == MINI_BATCH_SZIE
def test_call_fit_wrong_value_mini_batch_size(sagemaker_session):
    """``fit`` rejects a mini-batch size of 0 with ValueError."""
    estimator = LDA(base_job_name='lda', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    train_uri = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri, num_records=1, feature_dim=FEATURE_DIM, channel='train')

    zero_batch = 0
    with pytest.raises(ValueError):
        estimator.fit(records, zero_batch)
def test_all_hyperparameters(sagemaker_session):
    """``hyperparameters()`` reports every set hyperparameter as a string."""
    # NOTE(review): ``test_all_hyperparameters`` is defined more than once in this module;
    # only the last definition is collected by pytest.
    estimator = LDA(
        sagemaker_session=sagemaker_session,
        alpha0=2.2,
        max_restarts=3,
        max_iterations=10,
        tol=3.3,
        **ALL_REQ_ARGS
    )
    actual = estimator.hyperparameters()
    expected = {
        'num_topics': str(ALL_REQ_ARGS['num_topics']),
        'alpha0': '2.2',
        'max_restarts': '3',
        'max_iterations': '10',
        'tol': '3.3',
    }
    assert actual == expected
def test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session):
    """``_prepare_for_training`` rejects a mini-batch size of 0 with ValueError."""
    estimator = LDA(base_job_name="lda", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    train_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri, num_records=1, feature_dim=FEATURE_DIM, channel="train")

    zero_batch = 0
    with pytest.raises(ValueError):
        estimator._prepare_for_training(records, zero_batch)
def test_model_image(sagemaker_session):
    """A model created from a fitted LDA estimator uses ``image_uris.retrieve("lda", REGION)``."""
    estimator = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    train_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri, num_records=1, feature_dim=FEATURE_DIM, channel="train")
    estimator.fit(records, MINI_BATCH_SZIE)

    created_model = estimator.create_model()
    expected_uri = image_uris.retrieve("lda", REGION)
    assert expected_uri == created_model.image_uri
def test_all_hyperparameters(sagemaker_session):
    """``hyperparameters()`` reports every set hyperparameter as a string."""
    estimator = LDA(
        sagemaker_session=sagemaker_session,
        alpha0=2.2,
        max_restarts=3,
        max_iterations=10,
        tol=3.3,
        **ALL_REQ_ARGS
    )
    actual = estimator.hyperparameters()
    expected = {
        "num_topics": str(ALL_REQ_ARGS["num_topics"]),
        "alpha0": "2.2",
        "max_restarts": "3",
        "max_iterations": "10",
        "tol": "3.3",
    }
    assert actual == expected
def test_init_required_named(sagemaker_session):
    """Required arguments passed by keyword are stored on the estimator (v2 attribute names)."""
    # NOTE(review): ``test_init_required_named`` is defined more than once in this module;
    # only the last definition is collected by pytest.
    estimator = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    checks = [
        (estimator.role, COMMON_TRAIN_ARGS["role"]),
        (estimator.instance_count, INSTANCE_COUNT),
        (estimator.instance_type, COMMON_TRAIN_ARGS["instance_type"]),
        (estimator.num_topics, ALL_REQ_ARGS["num_topics"]),
    ]
    for observed, wanted in checks:
        assert observed == wanted
def test_init_required_named(sagemaker_session):
    """Required arguments passed by keyword are stored on the estimator (``train_*`` attribute names)."""
    estimator = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    checks = [
        (estimator.role, COMMON_TRAIN_ARGS['role']),
        (estimator.train_instance_count, TRAIN_INSTANCE_COUNT),
        (estimator.train_instance_type, COMMON_TRAIN_ARGS['train_instance_type']),
        (estimator.num_topics, ALL_REQ_ARGS['num_topics']),
    ]
    for observed, wanted in checks:
        assert observed == wanted
def test_init_required_positional(sagemaker_session):
    """Role, instance type and topic count are accepted positionally and stored on the estimator."""
    estimator = LDA(ROLE, INSTANCE_TYPE, NUM_TOPICS, sagemaker_session=sagemaker_session)

    assert estimator.role == ROLE
    assert estimator.instance_count == INSTANCE_COUNT
    assert estimator.instance_type == INSTANCE_TYPE
    assert estimator.num_topics == NUM_TOPICS
def test_predictor_custom_serialization(sagemaker_session):
    """``deploy`` passes caller-supplied serializer/deserializer objects through to the LDAPredictor."""
    estimator = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    train_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(train_uri, num_records=1, feature_dim=FEATURE_DIM, channel="train")
    estimator.fit(records, MINI_BATCH_SZIE)
    lda_model = estimator.create_model()

    serializer_stub = Mock()
    deserializer_stub = Mock()
    predictor = lda_model.deploy(
        1,
        INSTANCE_TYPE,
        serializer=serializer_stub,
        deserializer=deserializer_stub,
    )

    assert isinstance(predictor, LDAPredictor)
    # Identity checks: the exact objects handed to deploy must be the ones attached.
    assert predictor.serializer is serializer_stub
    assert predictor.deserializer is deserializer_stub
def test_image(sagemaker_session):
    """``train_image()`` returns the registry-hosted LDA image tagged ``1``."""
    # NOTE(review): ``test_image`` is defined more than once in this module; only the last
    # definition is collected by pytest.
    estimator = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    training_image = estimator.train_image()
    assert training_image == registry(REGION, 'lda') + '/lda:1'
def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value):
    """Constructing LDA with an invalid value for an optional hyperparameter raises ValueError.

    Args:
        sagemaker_session: session fixture forwarded to the estimator.
        optional_hyper_parameters: name of the optional hyperparameter to override
            (presumably supplied by a parametrize decorator outside this view).
        value: the invalid value to assign to that hyperparameter.
    """
    # Build the kwargs outside ``pytest.raises`` so that only the constructor call can
    # satisfy the expected ValueError; setup failures now surface as real errors.
    test_params = ALL_REQ_ARGS.copy()
    test_params[optional_hyper_parameters] = value
    with pytest.raises(ValueError):
        LDA(sagemaker_session=sagemaker_session, **test_params)
def test_required_hyper_parameters_value(sagemaker_session, required_hyper_parameters, value):
    """Constructing LDA with an invalid value for a required hyperparameter raises ValueError.

    Args:
        sagemaker_session: session fixture forwarded to the estimator.
        required_hyper_parameters: name of the required hyperparameter to override
            (presumably supplied by a parametrize decorator outside this view).
        value: the invalid value to assign to that hyperparameter.
    """
    # Build the kwargs outside ``pytest.raises`` so that only the constructor call can
    # satisfy the expected ValueError; setup failures now surface as real errors.
    test_params = ALL_REQ_ARGS.copy()
    test_params[required_hyper_parameters] = value
    with pytest.raises(ValueError):
        LDA(sagemaker_session=sagemaker_session, **test_params)
def test_image(sagemaker_session):
    """``training_image_uri()`` matches ``image_uris.retrieve("lda", REGION)``."""
    estimator = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    expected_uri = image_uris.retrieve("lda", REGION)
    assert expected_uri == estimator.training_image_uri()
def test_alpha0_validation_fail_type(sagemaker_session):
    """A non-numeric ``alpha0`` is rejected by the LDA constructor with ValueError."""
    non_numeric_alpha0 = 'other'
    with pytest.raises(ValueError):
        LDA(alpha0=non_numeric_alpha0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
def test_max_restarts_validation_fail_type2(sagemaker_session):
    """A fractional ``max_restarts`` is rejected by the LDA constructor with ValueError."""
    fractional_restarts = 0.1
    with pytest.raises(ValueError):
        LDA(max_restarts=fractional_restarts, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
def test_max_iterations_validation_fail_value(sagemaker_session):
    """A ``max_iterations`` of 0 is rejected by the LDA constructor with ValueError."""
    zero_iterations = 0
    with pytest.raises(ValueError):
        LDA(max_iterations=zero_iterations, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
def test_tol_validation_fail_value(sagemaker_session):
    """A ``tol`` of 0 is rejected by the LDA constructor with ValueError."""
    zero_tolerance = 0
    with pytest.raises(ValueError):
        LDA(tol=zero_tolerance, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
def test_num_topics_validation_fail_value(sagemaker_session):
    """A ``num_topics`` of 0 is rejected by the LDA constructor with ValueError."""
    zero_topics = 0
    # Unpack COMMON_TRAIN_ARGS rather than ALL_REQ_ARGS: the latter carries its own
    # ``num_topics`` entry, so passing it twice would raise TypeError (duplicate keyword)
    # instead of the ValueError under test.
    with pytest.raises(ValueError):
        LDA(num_topics=zero_topics, sagemaker_session=sagemaker_session, **COMMON_TRAIN_ARGS)