Пример #1
0
def test_model_image(sagemaker_session):
    """The model created from a fitted NTM estimator uses the ``ntm:1`` image."""
    estimator = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    s3_location = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    train_records = RecordSet(
        s3_location,
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    estimator.fit(train_records, MINI_BATCH_SIZE)

    created = estimator.create_model()
    # Expected URI is the regional registry host plus the repository:tag suffix.
    assert created.image == registry(REGION, "ntm") + "/ntm:1"
Пример #2
0
def test_model_image(sagemaker_session):
    """Check the image of the model produced by a trained NTM estimator."""
    ntm_estimator = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    records = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    ntm_estimator.fit(records, MINI_BATCH_SIZE)

    ntm_model = ntm_estimator.create_model()
    assert ntm_model.image == registry(REGION, "ntm") + "/ntm:1"
Пример #3
0
def test_call_fit_wrong_value_upper_mini_batch_size(sagemaker_session):
    """``fit`` raises ValueError when given a too-large mini-batch size (10001)."""
    estimator = NTM(
        base_job_name="ntm",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    records = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    oversized_batch = 10001
    with pytest.raises(ValueError):
        estimator.fit(records, oversized_batch)
Пример #4
0
def test_predictor_type(sagemaker_session):
    """Deploying the model of a fitted NTM estimator returns an ``NTMPredictor``."""
    estimator = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    training_input = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    estimator.fit(training_input, MINI_BATCH_SIZE)

    deployed = estimator.create_model().deploy(1, TRAIN_INSTANCE_TYPE)
    assert isinstance(deployed, NTMPredictor)
Пример #5
0
def test_predictor_type(sagemaker_session):
    """The predictor returned by ``deploy`` is an ``NTMPredictor`` instance."""
    ntm_estimator = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    s3_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(s3_uri, num_records=1, feature_dim=FEATURE_DIM, channel="train")
    ntm_estimator.fit(records, MINI_BATCH_SIZE)

    ntm_model = ntm_estimator.create_model()
    endpoint_predictor = ntm_model.deploy(1, TRAIN_INSTANCE_TYPE)

    assert isinstance(endpoint_predictor, NTMPredictor)
Пример #6
0
def test_call_fit_none_mini_batch_size(sagemaker_session):
    """Calling ``fit`` without a mini-batch size completes without raising.

    There is no explicit assertion: the test passes as long as ``fit`` returns.
    """
    estimator = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

    s3_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    records = RecordSet(s3_uri, num_records=1, feature_dim=FEATURE_DIM, channel="train")
    estimator.fit(records)
Пример #7
0
def test_call_fit(base_fit, sagemaker_session):
    """``NTM.fit`` passes the records and mini-batch size positionally to ``base_fit``.

    ``base_fit`` is a fixture that appears to be a mock of the underlying fit
    (it exposes ``assert_called_once`` / ``call_args``) — confirm in conftest.
    """
    estimator = NTM(
        base_job_name="ntm",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    records = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )

    estimator.fit(records, MINI_BATCH_SIZE)

    base_fit.assert_called_once()
    # call_args[0] holds the positional arguments of the single recorded call.
    positional = base_fit.call_args[0]
    assert len(positional) == 2
    assert positional[0] == records
    assert positional[1] == MINI_BATCH_SIZE
Пример #8
0
def test_call_fit(base_fit, sagemaker_session):
    """Verify that ``fit`` forwards exactly (records, mini-batch size) to ``base_fit``."""
    ntm_estimator = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

    s3_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    train_set = RecordSet(s3_uri, num_records=1, feature_dim=FEATURE_DIM, channel="train")

    ntm_estimator.fit(train_set, MINI_BATCH_SIZE)

    base_fit.assert_called_once()
    forwarded = base_fit.call_args[0]
    assert len(forwarded) == 2
    first_arg, second_arg = forwarded[0], forwarded[1]
    assert first_arg == train_set
    assert second_arg == MINI_BATCH_SIZE
Пример #9
0
def test_model_image(sagemaker_session):
    """The model's ``image_uri`` matches the URI retrieved for ``ntm`` in ``REGION``."""
    estimator = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    s3_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    train_records = RecordSet(s3_uri, num_records=1, feature_dim=FEATURE_DIM, channel="train")
    estimator.fit(train_records, MINI_BATCH_SIZE)

    created_model = estimator.create_model()
    assert image_uris.retrieve("ntm", REGION) == created_model.image_uri
Пример #10
0
def test_predictor_custom_serialization(sagemaker_session):
    """Custom serializer/deserializer passed to ``deploy`` end up on the predictor.

    The predictor must still be an ``NTMPredictor``. It must also hold the exact
    serializer and deserializer objects that were supplied (identity, not equality).
    """
    estimator = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    s3_uri = "s3://{}/{}".format(BUCKET_NAME, PREFIX)
    train_records = RecordSet(s3_uri, num_records=1, feature_dim=FEATURE_DIM, channel="train")
    estimator.fit(train_records, MINI_BATCH_SIZE)
    ntm_model = estimator.create_model()

    serializer_stub = Mock()
    deserializer_stub = Mock()
    deployed = ntm_model.deploy(
        1, INSTANCE_TYPE, serializer=serializer_stub, deserializer=deserializer_stub
    )

    assert isinstance(deployed, NTMPredictor)
    assert deployed.serializer is serializer_stub
    assert deployed.deserializer is deserializer_stub
Пример #11
0
def test_call_fit_none_mini_batch_size(sagemaker_session):
    """``fit`` accepts a call with no mini-batch size argument (smoke test, no asserts)."""
    ntm_estimator = NTM(
        base_job_name="ntm",
        sagemaker_session=sagemaker_session,
        **ALL_REQ_ARGS
    )

    train_set = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    # Passing only the records exercises the default mini-batch-size path.
    ntm_estimator.fit(train_set)