def test_mxnet_neo(strftime, sagemaker_session, mxnet_version):
    """Neo compile/deploy test for the MXNet estimator (older keyword style).

    NOTE(review): another ``test_mxnet_neo`` is defined directly after this
    function in the same module. Python rebinds the name at import time, so
    this definition is overwritten and pytest never collects or runs it.

    This version uses ``train_instance_count``, ``train_instance_type`` and
    ``compiled_model.image``. The later definition uses ``instance_count``,
    ``instance_type`` and ``image_uri``. It therefore looks like a leftover
    from a migration or merge; consider deleting it rather than reviving it.

    The flow is as follows:

    - Fit an estimator on ``'s3://mybucket/train'`` and compile it for the
      ``ml_c4`` target family.
    - Check the recorded session calls and the compile-job arguments.
    - Deploy with the compiled model, then with a non-matching instance type
      (must raise), then without the compiled model.

    Args:
        strftime: fixture not referenced in the body. Presumably it patches
            time so generated job names are deterministic (TODO confirm).
        sagemaker_session: session object whose ``method_calls`` are
            inspected. Assumes a ``unittest.mock`` mock; verify against the
            fixture.
        mxnet_version: framework version passed to the estimator and used to
            build the expected Neo inference image.
    """
    mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
               train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
               framework_version=mxnet_version)

    inputs = 's3://mybucket/train'

    mx.fit(inputs=inputs)

    input_shape = {'data': [100, 1, 28, 28]}
    output_location = 's3://neo-sdk-test'

    compiled_model = mx.compile_model(target_instance_family='ml_c4', input_shape=input_shape,
                                      output_path=output_location)

    # Names of the session methods invoked so far, in call order.
    sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
    assert sagemaker_call_names == ['train', 'logs_for_job', 'sagemaker_client.describe_training_job',
                                    'compile_model', 'wait_for_compilation_job']

    # method_calls[3] is the 'compile_model' entry asserted above.
    # Element [2] of a mock call record is its keyword-argument dict.
    expected_compile_model_args = _create_compilation_job(json.dumps(input_shape), output_location)
    actual_compile_model_args = sagemaker_session.method_calls[3][2]
    assert expected_compile_model_args == actual_compile_model_args

    assert compiled_model.image == _neo_inference_image(mxnet_version)

    predictor = mx.deploy(1, CPU, use_compiled_model=True)
    assert isinstance(predictor, MXNetPredictor)

    # The model was compiled for 'ml_c4'. Deploying the compiled model to
    # CPU_C5 (presumably an ml.c5 type, per the name) is expected to raise.
    with pytest.raises(Exception) as wrong_target:
        mx.deploy(1, CPU_C5, use_compiled_model=True)
    assert str(wrong_target.value).startswith('No compiled model for')

    # deploy without sagemaker Neo should continue to work
    mx.deploy(1, CPU)
def test_mxnet_neo(strftime, sagemaker_session, neo_mxnet_version):
    """Compile a trained MXNet estimator with SageMaker Neo, then deploy it.

    The test checks the following:

    - the sequence of calls recorded on the session;
    - the keyword arguments forwarded to ``compile_model``;
    - the Neo inference image URI;
    - that ``use_compiled_model=True`` only succeeds for the instance family
      the model was compiled for.

    A regular (non-Neo) deploy must keep working afterwards.

    Args:
        strftime: fixture not referenced in the body. Presumably it pins time
            so generated job names are deterministic (TODO confirm).
        sagemaker_session: session object whose ``method_calls`` are
            inspected. Assumes a ``unittest.mock`` mock; verify against the
            fixture.
        neo_mxnet_version: framework version handed to ``compile_model`` and
            used to build the expected inference image.
    """
    estimator = MXNet(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        framework_version="1.6",
        py_version="py3",
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        sagemaker_session=sagemaker_session,
        base_job_name="sagemaker-mxnet",
    )
    estimator.fit()

    shape = {"data": [100, 1, 28, 28]}
    neo_output = "s3://neo-sdk-test"

    compiled = estimator.compile_model(
        target_instance_family="ml_c4",
        input_shape=shape,
        output_path=neo_output,
        framework="mxnet",
        framework_version=neo_mxnet_version,
    )

    recorded = sagemaker_session.method_calls
    assert [entry[0] for entry in recorded] == [
        "train",
        "logs_for_job",
        "sagemaker_client.describe_training_job",
        "compile_model",
        "wait_for_compilation_job",
    ]

    # Entry 3 is the 'compile_model' call checked above.
    # Slot 2 of a mock call record holds its keyword arguments.
    assert _create_compilation_job(json.dumps(shape), neo_output) == recorded[3][2]

    assert compiled.image_uri == _neo_inference_image(neo_mxnet_version)

    assert isinstance(estimator.deploy(1, CPU, use_compiled_model=True), MXNetPredictor)

    # Compiled for 'ml_c4', so requesting the compiled model on CPU_C5 must fail.
    with pytest.raises(Exception) as excinfo:
        estimator.deploy(1, CPU_C5, use_compiled_model=True)
    assert str(excinfo.value).startswith("No compiled model for")

    # A plain deploy without Neo must still succeed after compilation.
    estimator.deploy(1, CPU)