Exemplo n.º 1
0
def test_model_image(sagemaker_session):
    ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
    ntm.fit(data, MINI_BATCH_SIZE)

    model = ntm.create_model()
    assert model.image == registry(REGION, "ntm") + '/ntm:1'
Exemplo n.º 2
0
def test_prepare_for_training_wrong_value_upper_mini_batch_size(sagemaker_session):
    ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
                     channel='train')
    with pytest.raises(ValueError):
        ntm._prepare_for_training(data, 10001)
Exemplo n.º 3
0
def test_predictor_type(sagemaker_session):
    ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
    ntm.fit(data, MINI_BATCH_SIZE)
    model = ntm.create_model()
    predictor = model.deploy(1, TRAIN_INSTANCE_TYPE)

    assert isinstance(predictor, NTMPredictor)
Exemplo n.º 4
0
def test_call_fit(base_fit, sagemaker_session):
    ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')

    ntm.fit(data, MINI_BATCH_SIZE)

    base_fit.assert_called_once()
    assert len(base_fit.call_args[0]) == 2
    assert base_fit.call_args[0][0] == data
    assert base_fit.call_args[0][1] == MINI_BATCH_SIZE
Exemplo n.º 5
0
def test_all_hyperparameters(sagemaker_session):
    ntm = NTM(sagemaker_session=sagemaker_session,
              encoder_layers=[1, 2, 3], epochs=3, encoder_layers_activation='tanh', optimizer='sgd',
              tolerance=0.05, num_patience_epochs=2, batch_norm=False, rescale_gradient=0.5, clip_gradient=0.5,
              weight_decay=0.5, learning_rate=0.5, **ALL_REQ_ARGS)
    assert ntm.hyperparameters() == dict(
        num_topics=str(ALL_REQ_ARGS['num_topics']),
        encoder_layers='[1, 2, 3]',
        epochs='3',
        encoder_layers_activation='tanh',
        optimizer='sgd',
        tolerance='0.05',
        num_patience_epochs='2',
        batch_norm='False',
        rescale_gradient='0.5',
        clip_gradient='0.5',
        weight_decay='0.5',
        learning_rate='0.5'
    )
Exemplo n.º 6
0
def test_required_hyper_parameters_value(sagemaker_session,
                                         required_hyper_parameters, value):
    with pytest.raises(ValueError):
        test_params = ALL_REQ_ARGS.copy()
        test_params[required_hyper_parameters] = value
        NTM(sagemaker_session=sagemaker_session, **test_params)
Exemplo n.º 7
0
def test_weight_decay_fail_type(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(weight_decay='other',
            sagemaker_session=sagemaker_session,
            **ALL_REQ_ARGS)
Exemplo n.º 8
0
def test_weight_decay_validation_fail_value_lower(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(weight_decay=-1,
            sagemaker_session=sagemaker_session,
            **ALL_REQ_ARGS)
Exemplo n.º 9
0
def test_clip_gradient_fail_type(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(clip_gradient='other',
            sagemaker_session=sagemaker_session,
            **ALL_REQ_ARGS)
Exemplo n.º 10
0
def test_clip_gradient_validation_fail_value(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(clip_gradient=0,
            sagemaker_session=sagemaker_session,
            **ALL_REQ_ARGS)
Exemplo n.º 11
0
def test_num_patience_epochs_validation_fail_value_upper(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(num_patience_epochs=100,
            sagemaker_session=sagemaker_session,
            **ALL_REQ_ARGS)
Exemplo n.º 12
0
def test_rescale_gradient_validation_fail_value_upper(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(rescale_gradient=10,
            sagemaker_session=sagemaker_session,
            **ALL_REQ_ARGS)
Exemplo n.º 13
0
def test_call_fit_none_mini_batch_size(sagemaker_session):
    ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
                     channel='train')
    ntm.fit(data)
Exemplo n.º 14
0
def test_image(sagemaker_session):
    ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    assert ntm.train_image() == registry(REGION, "ntm") + '/ntm:1'
Exemplo n.º 15
0
def test_optimizer_validation_fail_value(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(optimizer='other',
            sagemaker_session=sagemaker_session,
            **ALL_REQ_ARGS)
Exemplo n.º 16
0
def test_iterable_hyper_parameters_type(sagemaker_session,
                                        iterable_hyper_parameters, value):
    with pytest.raises(TypeError):
        test_params = ALL_REQ_ARGS.copy()
        test_params.update({iterable_hyper_parameters: value})
        NTM(sagemaker_session=sagemaker_session, **test_params)
Exemplo n.º 17
0
def test_encoder_layers_activation_validation_fail_value(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(encoder_layers_activation='other',
            sagemaker_session=sagemaker_session,
            **ALL_REQ_ARGS)
Exemplo n.º 18
0
def test_epochs_validation_fail_value_lower(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(epochs=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
Exemplo n.º 19
0
def test_encoder_layers_validation_fail_type(sagemaker_session):
    with pytest.raises(TypeError):
        NTM(encoder_layers=0,
            sagemaker_session=sagemaker_session,
            **ALL_REQ_ARGS)
Exemplo n.º 20
0
def test_num_topics_validation_fail_value_upper(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(num_topics=10000,
            sagemaker_session=sagemaker_session,
            **COMMON_TRAIN_ARGS)
Exemplo n.º 21
0
def test_num_topics_validation_fail_type(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(num_topics='other',
            sagemaker_session=sagemaker_session,
            **COMMON_TRAIN_ARGS)
Exemplo n.º 22
0
def test_tolerance_validation_fail_type(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(tolerance='other',
            sagemaker_session=sagemaker_session,
            **ALL_REQ_ARGS)
Exemplo n.º 23
0
def test_optional_hyper_parameters_value(sagemaker_session,
                                         optional_hyper_parameters, value):
    with pytest.raises(ValueError):
        test_params = ALL_REQ_ARGS.copy()
        test_params.update({optional_hyper_parameters: value})
        NTM(sagemaker_session=sagemaker_session, **test_params)
Exemplo n.º 24
0
def test_tolerance_validation_fail_value_upper(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(tolerance=0.5, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
Exemplo n.º 25
0
def test_num_patience_epochs_validation_fail_type(sagemaker_session):
    with pytest.raises(ValueError):
        NTM(num_patience_epochs='other',
            sagemaker_session=sagemaker_session,
            **ALL_REQ_ARGS)
Exemplo n.º 26
0
def test_image(sagemaker_session):
    ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    assert ntm.train_image() == registry(REGION, "ntm") + '/ntm:1'
Exemplo n.º 27
0
def test_image(sagemaker_session):
    ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    assert image_uris.retrieve("ntm", REGION) == ntm.training_image_uri()