def test_train_with_no_kwargs_in_user_module(mxc):
    """Check that ``train`` calls the ``train`` function of the user module.

    ``importlib.import_module`` is replaced by an object built with
    ``train_no_kwargs_mock``, and ``inspect.getargspec`` is stubbed so that it
    reports the signature of ``NoKWArgsModule.train``.
    """
    from mxnet_container import train

    # The S3 download, untar and hostname-lookup callables are swapped for
    # mocks while ``train`` runs; those mocks are never inspected, so they
    # are left unnamed.
    with patch('container_support.download_s3_resource'), \
            patch('container_support.untar_directory'), \
            patch('socket.gethostbyname'), \
            patch('inspect.getargspec') as argspec_mock, \
            patch('importlib.import_module', new_callable=train_no_kwargs_mock) as import_mock:
        argspec_mock.return_value = getargspec_orig(NoKWArgsModule.train)

        train(optml())

        # Whatever the patched import returns stands in for the user module.
        user_module = import_mock.return_value
        assert user_module.train.called
def test_train_failing_script(mxc):
    """Check that a ``ValueError`` raised by the user ``train`` escapes ``train``.

    ``importlib.import_module`` is replaced by an object built with
    ``train_kwargs_mock`` whose ``train`` attribute is set up to raise, and
    ``inspect.getargspec`` is stubbed so that it reports the signature of
    ``KWArgsModule.train``.
    """
    from mxnet_container import train

    def always_fail(*args, **kwargs):
        # Accepts any call signature and fails unconditionally.
        raise ValueError("I failed")

    # The S3 download, untar and hostname-lookup callables are swapped for
    # mocks while ``train`` runs; those mocks are never inspected, so they
    # are left unnamed.
    with patch('container_support.download_s3_resource'), \
            patch('container_support.untar_directory'), \
            patch('socket.gethostbyname'), \
            patch('inspect.getargspec') as argspec_mock, \
            patch('importlib.import_module', new_callable=train_kwargs_mock) as import_mock:
        argspec_mock.return_value = getargspec_orig(KWArgsModule.train)
        user_train = import_mock.return_value.train
        user_train.side_effect = always_fail

        with pytest.raises(ValueError):
            train(optml())

        # Checked outside ``pytest.raises`` so that the assertion really runs.
        assert import_mock.return_value.train.called