def test_module_default_input_fn_with_accelerator(decode, mx_ndarray_iter,
                                                  mx_ndarray, mx_eia):
    """Run ``default_input_fn`` on JSON input and check the CPU context.

    Asserts that ``as_in_context`` on the array mock returned by
    ``mx_ndarray`` was called with ``mx.cpu()``.

    NOTE(review): the mock arguments are presumably injected by patch
    decorators that sit outside this block -- confirm against the file.
    """
    fake_array = Mock(shape=(1, (1, )))
    # ``as_in_context`` hands back the same mock, so any chained call made by
    # the handler stays on this one object.
    fake_array.as_in_context.return_value = fake_array
    mx_ndarray.return_value = fake_array

    fake_model = Mock(data_shapes=[(1, (1, ))])
    handler = DefaultModuleInferenceHandler()
    handler.default_input_fn(Mock(), 'application/json', fake_model)

    fake_array.as_in_context.assert_called_with(mx.cpu())
def test_module_default_input_fn_with_npy(decode, mx_ndarray_iter):
    """Run ``default_input_fn`` on an NPY payload and check its collaborators.

    Asserts that ``decode`` received the payload together with its content
    type, and that ``mx_ndarray_iter`` recorded the expected call.
    """
    fake_model = Mock(data_shapes=[(1, (1, ))])
    payload = Mock()
    npy_type = 'application/x-npy'

    handler = DefaultModuleInferenceHandler()
    handler.default_input_fn(payload, npy_type, fake_model)

    decode.assert_called_with(payload, npy_type)

    expected_iter_call = call(
        mx.nd.array([0]),
        batch_size=1,
        last_batch_handle='pad',
    )
    assert expected_iter_call in mx_ndarray_iter.mock_calls
def test_module_default_input_fn_with_csv(decode, mx_ndarray_iter, mx_ndarray):
    """Run ``default_input_fn`` on a CSV payload and check its collaborators.

    Asserts that ``decode`` received the payload with its content type, that
    the array mock was reshaped to ``(1, )``, and that ``mx_ndarray_iter``
    recorded the expected call.
    """
    fake_array = Mock(shape=(1, (1, )))
    # Both chained operations return the same mock so every call made on the
    # array is recorded on a single object.
    fake_array.reshape.return_value = fake_array
    fake_array.as_in_context.return_value = fake_array
    mx_ndarray.return_value = fake_array

    fake_model = Mock(data_shapes=[(1, (1, ))])
    payload = Mock()
    csv_type = 'text/csv'

    handler = DefaultModuleInferenceHandler()
    handler.default_input_fn(payload, csv_type, fake_model)

    decode.assert_called_with(payload, csv_type)
    fake_array.reshape.assert_called_with((1, ))

    expected_iter_call = call(
        mx.nd.array([0]),
        batch_size=1,
        last_batch_handle='pad',
    )
    assert expected_iter_call in mx_ndarray_iter.mock_calls
Ejemplo n.º 4
0
 def __init__(self):
     """Set up the parent class with a default module inference handler."""
     default_handler = DefaultModuleInferenceHandler()
     super(MXNetModuleTransformer, self).__init__(default_handler)
def test_module_default_predict_fn():
    """Check ``default_predict_fn`` calls ``module.predict`` with the data."""
    fake_module = Mock()
    fake_data = Mock()

    handler = DefaultModuleInferenceHandler()
    handler.default_predict_fn(fake_data, fake_module)

    fake_module.predict.assert_called_with(fake_data)
def test_module_default_input_fn_invalid_content_type():
    """Expect ``UnsupportedFormatError`` for an unrecognised content type."""
    bad_type = 'bad/content-type'

    with pytest.raises(errors.UnsupportedFormatError) as exc_info:
        DefaultModuleInferenceHandler().default_input_fn(None, bad_type)

    # ``match`` checks the exception text against this pattern.
    exc_info.match(
        'Content type bad/content-type is not supported by this framework')
def test_default_module_valid_content_types():
    """Check ``VALID_CONTENT_TYPES`` equals ``(JSON, CSV, NPY)`` in order."""
    actual_types = DefaultModuleInferenceHandler().VALID_CONTENT_TYPES
    expected_types = (
        content_types.JSON,
        content_types.CSV,
        content_types.NPY,
    )
    assert actual_types == expected_types