def test_module_default_input_fn_with_accelerator(decode, mx_ndarray_iter, mx_ndarray, mx_eia):
    """Check that the array built by ``default_input_fn`` is placed on the CPU context.

    The arguments are presumably mocks injected by ``patch`` decorators that are
    not visible in this view; ``mx_eia`` is not referenced in the body.
    """
    # The fake array returns itself from as_in_context so chained use keeps
    # operating on the same mock.
    fake_array = Mock(shape=(1, (1, )))
    fake_array.as_in_context.return_value = fake_array
    mx_ndarray.return_value = fake_array

    module = Mock(data_shapes=[(1, (1, ))])
    handler = DefaultModuleInferenceHandler()
    handler.default_input_fn(Mock(), 'application/json', module)

    expected_context = mx.cpu()
    fake_array.as_in_context.assert_called_with(expected_context)
def test_module_default_input_fn_with_npy(decode, mx_ndarray_iter):
    """Check NPY input is decoded and wrapped in an NDArrayIter-style iterator.

    ``decode`` and ``mx_ndarray_iter`` are presumably mocks injected by
    ``patch`` decorators that are not visible in this view.
    """
    payload = Mock()
    npy_content_type = 'application/x-npy'
    module = Mock(data_shapes=[(1, (1, ))])

    handler = DefaultModuleInferenceHandler()
    handler.default_input_fn(payload, npy_content_type, module)

    # The raw payload must be handed to the decoder together with its content type.
    decode.assert_called_with(payload, npy_content_type)

    # The iterator mock must have been invoked with the expected arguments.
    expected_iter_call = call(mx.nd.array([0]), batch_size=1, last_batch_handle='pad')
    recorded_calls = mx_ndarray_iter.mock_calls
    assert expected_iter_call in recorded_calls
def test_module_default_input_fn_with_csv(decode, mx_ndarray_iter, mx_ndarray):
    """Check CSV input is decoded, reshaped, and wrapped in an iterator.

    The arguments are presumably mocks injected by ``patch`` decorators that are
    not visible in this view.
    """
    # The fake array returns itself from reshape/as_in_context so every step
    # of the handler keeps operating on the same mock.
    fake_array = Mock(shape=(1, (1, )))
    fake_array.reshape.return_value = fake_array
    fake_array.as_in_context.return_value = fake_array
    mx_ndarray.return_value = fake_array

    payload = Mock()
    csv_content_type = 'text/csv'
    module = Mock(data_shapes=[(1, (1, ))])

    handler = DefaultModuleInferenceHandler()
    handler.default_input_fn(payload, csv_content_type, module)

    decode.assert_called_with(payload, csv_content_type)
    # NOTE(review): the expected shape presumably comes from the model's
    # data_shapes entry -- confirm against the handler implementation.
    fake_array.reshape.assert_called_with((1, ))

    expected_iter_call = call(mx.nd.array([0]), batch_size=1, last_batch_handle='pad')
    recorded_calls = mx_ndarray_iter.mock_calls
    assert expected_iter_call in recorded_calls
def __init__(self):
    """Initialize the parent class with a new ``DefaultModuleInferenceHandler``."""
    default_handler = DefaultModuleInferenceHandler()
    super(MXNetModuleTransformer, self).__init__(default_handler)
def test_module_default_predict_fn():
    """Check that ``default_predict_fn`` calls ``predict`` on the module with the data."""
    fake_module = Mock()
    batch = Mock()

    handler = DefaultModuleInferenceHandler()
    handler.default_predict_fn(batch, fake_module)

    fake_module.predict.assert_called_with(batch)
def test_module_default_input_fn_invalid_content_type():
    """Check that an unknown content type raises ``UnsupportedFormatError``."""
    handler = DefaultModuleInferenceHandler()

    with pytest.raises(errors.UnsupportedFormatError) as excinfo:
        handler.default_input_fn(None, 'bad/content-type')

    # Checked after the ``with`` block: anything following the raising call
    # inside the block would never execute.
    excinfo.match('Content type bad/content-type is not supported by this framework')
def test_default_module_valid_content_types():
    """Check the handler's ``VALID_CONTENT_TYPES`` is exactly (JSON, CSV, NPY), in order."""
    handler = DefaultModuleInferenceHandler()
    expected_types = (content_types.JSON, content_types.CSV, content_types.NPY)
    assert handler.VALID_CONTENT_TYPES == expected_types