def _create_storage(self):
     temp_dir = tempfile.gettempdir()
     self._dbfile = os.path.join(temp_dir, "temp.db")
     self._engine = create_engine('sqlite:///' + self._dbfile)
     self._meta = sqla.MetaData()
     self.storage = SQLAStorage(self._engine, metadata=self._meta)
     self._meta.create_all(bind=self._engine)
Beispiel #2
0
 def _create_storage(self):
     temp_dir = tempfile.gettempdir()
     self._dbfile = os.path.join(temp_dir, "temp.db")
     conn_string = 'sqlite:///' + self._dbfile
     engine = create_engine(conn_string)
     meta = MetaData()
     self.storage = SQLAStorage(engine, metadata=meta)
     meta.create_all(bind=engine)
 def _create_storage(self):
     temp_dir = tempfile.gettempdir()
     self._dbfile = os.path.join(temp_dir, "temp.db")
     self._engine = create_engine('sqlite:///'+self._dbfile)
     self._meta = sqla.MetaData()
     self.storage = SQLAStorage(self._engine, metadata=self._meta)
     self._meta.create_all(bind=self._engine)
 def _create_storage(self):
     self._engine = create_engine(
         "postgresql+psycopg2://postgres:@localhost/flask_blogging",
         isolation_level="AUTOCOMMIT")
     self._meta = sqla.MetaData()
     self.storage = SQLAStorage(self._engine, metadata=self._meta)
     self._meta.create_all(bind=self._engine)
 def _create_storage(self):
     temp_dir = tempfile.gettempdir()
     self._dbfile = os.path.join(temp_dir, "temp.db")
     conn_string = 'sqlite:///'+self._dbfile
     engine = create_engine(conn_string)
     meta = MetaData()
     self.storage = SQLAStorage(engine, metadata=meta)
     meta.create_all(bind=engine)
    def setUp(self):
        FlaskBloggingTestCase.setUp(self)

        temp_dir = tempfile.gettempdir()
        self._dbfile = os.path.join(temp_dir, "temp.db")
        conn_string = self._conn_string(self._dbfile)
        self.app.config["SQLALCHEMY_BINDS"] = {'blog': conn_string}
        self._db = SQLAlchemy(self.app)
        self.storage = SQLAStorage(db=self._db, bind="blog")
        self._engine = self._db.get_engine(self.app, bind="blog")
        self._meta = self._db.metadata
        self._db.create_all(bind=["blog"])
Beispiel #7
0
class TestViews(FlaskBloggingTestCase):
    def _create_storage(self):
        temp_dir = tempfile.gettempdir()
        self._dbfile = os.path.join(temp_dir, "temp.db")
        conn_string = 'sqlite:///' + self._dbfile
        engine = create_engine(conn_string)
        meta = MetaData()
        self.storage = SQLAStorage(engine, metadata=meta)
        meta.create_all(bind=engine)

    def _create_blogging_engine(self):
        return BloggingEngine(self.app, self.storage)

    def setUp(self):
        FlaskBloggingTestCase.setUp(self)
        self._create_storage()
        self.app.config["BLOGGING_URL_PREFIX"] = "/blog"
        self.app.config["BLOGGING_PLUGINS"] = []
        self.engine = self._create_blogging_engine()
        self.login_manager = LoginManager(self.app)

        @self.login_manager.user_loader
        @self.engine.user_loader
        def load_user(user_id):
            return TestUser(user_id)

        @self.app.route("/login/<username>/",
                        methods=["POST"],
                        defaults={"blogger": 0})
        @self.app.route("/login/<username>/<int:blogger>/", methods=["POST"])
        def login(username, blogger):
            this_user = TestUser(username)
            login_user(this_user)
            if blogger:
                identity_changed.send(current_app._get_current_object(),
                                      identity=Identity(username))
            return redirect("/")

        @self.app.route("/logout/")
        def logout():
            logout_user()
            identity_changed.send(current_app._get_current_object(),
                                  identity=AnonymousIdentity())
            return redirect("/")

        for i in range(20):
            tags = ["hello"] if i < 10 else ["world"]
            user = "******" if i < 10 else "newuser"
            self.storage.save_post(title="Sample Title%d" % i,
                                   text="Sample Text%d" % i,
                                   user_id=user,
                                   tags=tags)

    def tearDown(self):
        os.remove(self._dbfile)

    def test_index(self):
        response = self.client.get("/blog/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog")
        self.assertEqual(response.status_code, 301)

        response = self.client.get("/blog/5/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/5/2/")
        self.assertEqual(response.status_code, 200)

    def test_post_by_id(self):
        response = self.client.get("/blog/page/1/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/page/1/sample-title/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/page/1")  # trailing slash redirect
        self.assertEqual(response.status_code, 301)

    def test_post_by_tag(self):
        response = self.client.get("/blog/tag/hello/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/tag/hello/5/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/tag/hello/5/2/")
        self.assertEqual(response.status_code, 200)

    def test_post_by_author(self):
        response = self.client.get("/blog/author/newuser/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/author/newuser/5/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/author/newuser/5/2/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/author/nonexistent_user/",
                                   follow_redirects=True)
        assert "No posts found for this user!" in str(response.data)

    def test_editor_get(self):
        user_id = "testuser"
        with self.client:
            # access to editor should be forbidden before login
            response = self.client.get("/blog/editor/")
            self.assertEqual(response.status_code, 401)

            response = self.client.get("/blog/editor/1/")
            self.assertEqual(response.status_code, 401)

            self.login(user_id)
            self.assertEquals(current_user.get_id(), user_id)
            response = self.client.get("/blog/editor/")
            assert response.status_code == 200

            for i in range(1, 21):
                # logged in user can edit their post, and will be redirected
                # if they try to edit other's post
                response = self.client.get("/blog/editor/%d/" % i)
                expected_status_code = 200 if i <= 10 else 302
                self.assertEqual(
                    response.status_code, expected_status_code,
                    "Error for item %d %d" % (i, response.status_code))
            # logout and the access should be gone again
            self.logout()
            response = self.client.get("/blog/editor/")
            self.assertEqual(response.status_code, 401)

            response = self.client.get("/blog/editor/1/")
            self.assertEqual(response.status_code, 401)

    def test_editor_post(self):
        user_id = "testuser"
        with self.client:
            # access to editor should be forbidden before login
            response = self.client.get("/blog/page/21/", follow_redirects=True)
            assert "The page you are trying to access is not valid!" in \
                   str(response.data)

            response = self.client.post("/blog/editor/")
            self.assertEqual(response.status_code, 401)

            response = self.client.post("/blog/editor/1/")
            self.assertEqual(response.status_code, 401)

            self.login(user_id)
            self.assertEquals(current_user.get_id(), user_id)

            response = self.client.post("/blog/editor/",
                                        data=dict(text="Test Text",
                                                  tags="tag1, tag2"))
            # should give back the editor page
            self.assertEqual(response.status_code, 200)

            response = self.client.post("/blog/editor/",
                                        data=dict(title="Test Title",
                                                  text="Test Text",
                                                  tags="tag1, tag2"))
            self.assertEqual(response.status_code, 302)

            response = self.client.get("/blog/page/21/")
            self.assertEqual(response.status_code, 200)

    def test_editor_edit_page(self):
        user_id = "testuser"
        with self.client:
            self.login(user_id)
            response = self.client.post("/blog/editor/1/",
                                        data=dict(title="Sample Title0-Edited",
                                                  text="Sample Text0-Edited",
                                                  tags="tag1, tag2"))

            response = self.client.get("/blog/100/")
            self.assertEqual(response.status_code, 200)

            pattern = re.compile(b"<h1>.*</h1>")
            headings = pattern.findall(response.data)
            self.assertEqual(len(headings), 20)
            self.assertEqual(headings[-1], b"<h1>Sample Title0-Edited</h1>")

            return

    def test_delete(self):
        user_id = "testuser"
        with self.client:

            # Anonymous user cannot delete
            response = self.client.post("/blog/delete/1/")
            self.assertEqual(response.status_code, 401)

            # a user cannot delete another person's post
            self.login(user_id)
            self.assertEquals(current_user.get_id(), user_id)
            response = self.client.post("/blog/delete/11/",
                                        follow_redirects=True)
            assert "You do not have the rights to delete this post" in \
                   str(response.data)

            # a user can delete his posts
            response = self.client.post("/blog/delete/1/",
                                        follow_redirects=True)
            assert "Your post was successfully deleted" in str(response.data)

    def login(self, user_id, blogger=False):
        if blogger:
            return self.client.post("/login/%s/1/" % user_id,
                                    follow_redirects=True)
        else:
            return self.client.post("/login/%s/" % user_id,
                                    follow_redirects=True)

    def logout(self):
        return self.client.get("/logout/")

    def test_sitemap(self):
        with self.client:
            # access to editor should be forbidden before login
            response = self.client.get("/blog/sitemap.xml")
            self.assertEqual(response.status_code, 200)

    def test_atom(self):
        with self.client:
            # access to editor should be forbidden before login
            response = self.client.get("/blog/feeds/all.atom.xml")
            self.assertEqual(response.status_code, 200)

    def test_posts_per_page(self):
        posts_per_page = 5
        self.app.config["BLOGGING_POSTS_PER_PAGE"] = posts_per_page
        with self.client:
            pattern = re.compile(b"<h1>.*</h1>")
            # index page
            response = self.client.get("/blog/")
            headings = pattern.findall(response.data)
            self.assertEqual(len(headings), posts_per_page)

            # tag page
            response = self.client.get("/blog/tag/hello/")
            headings = pattern.findall(response.data)
            self.assertEqual(len(headings), posts_per_page)

            # author page
            response = self.client.get("/blog/author/testuser/")
            headings = pattern.findall(response.data)
            self.assertEqual(len(headings), posts_per_page)

    def test_url_construction(self):
        ctx = self.app.test_request_context()
        ctx.push()
        index_url = url_for("blogging.index")
        self.assertEqual(index_url, "/blog/")

        # index url
        index_url = url_for("blogging.index", count=10)
        self.assertEqual(index_url, "/blog/10/")

        index_url = url_for("blogging.index", count=10, page=2)
        self.assertEqual(index_url, "/blog/10/2/")

        # page by id
        page_url = url_for("blogging.page_by_id", post_id=5)
        self.assertEqual(page_url, "/blog/page/5/")

        # posts by tag
        tag_url = url_for("blogging.posts_by_tag", tag="hello")
        self.assertEqual(tag_url, "/blog/tag/hello/")

        # posts by author
        author_url = url_for("blogging.posts_by_author", user_id="newuser")
        self.assertEqual(author_url, "/blog/author/newuser/")

        # sitemap
        sitemap_url = url_for("blogging.sitemap")
        self.assertEqual(sitemap_url, "/blog/sitemap.xml")

        # feeds
        feed_url = url_for("blogging.feed")
        self.assertEqual(feed_url, "/blog/feeds/all.atom.xml")

        ctx.pop()

    def _set_identity_loader(self, role_name):
        @identity_loaded.connect_via(self.app)
        def on_identity_loaded(sender, identity):
            identity.user = current_user
            if hasattr(current_user, "id"):
                identity.provides.add(UserNeed(current_user.id))
            identity.provides.add(RoleNeed(role_name))

    def test_permissions_editor(self):
        self.app.config["BLOGGING_PERMISSIONS"] = True
        self.app.config["BLOGGING_PERMISSIONNAME"] = "testblogger"
        user_id = "newuser"
        self._set_identity_loader(
            self.app.config.get("BLOGGING_PERMISSIONNAME", "blogger"))

        with self.client:
            response = self.client.post("/blog/editor/")
            self.assertEqual(response.status_code, 401)

            response = self.client.post("/blog/editor/1/")
            self.assertEqual(response.status_code, 401)

            self.login(user_id)
            response = self.client.post("/blog/editor/")
            self.assertEqual(response.status_code, 302)

            response = self.client.post("/blog/editor/1/")
            self.assertEqual(response.status_code, 302)

            self.logout()

            self.login(user_id, blogger=True)
            response = self.client.post("/blog/editor/")
            self.assertEqual(response.status_code, 200)

            response = self.client.post("/blog/editor/1/")
            self.assertEqual(response.status_code, 200)

            test_permission = Permission(RoleNeed("testblogger"))
            blogger_permission = Permission(RoleNeed("blogger"))
            self.assertTrue(
                test_permission.issubset(self.engine.blogger_permission))
            self.assertFalse(
                blogger_permission.issubset(self.engine.blogger_permission))

    def test_permissions_delete(self):
        self.app.config["BLOGGING_PERMISSIONS"] = True
        # Assuming "BLOGGING_PERMISSIONNAME" read failure
        # self.app.config["BLOGGING_PERMISSIONNAME"] = None
        user_id = "testuser"
        self._set_identity_loader(
            self.app.config.get("BLOGGING_PERMISSIONNAME", "blogger"))

        with self.client:
            # Anonymous user cannot delete
            response = self.client.post("/blog/delete/1/")
            self.assertEqual(response.status_code, 401)

            self.login(user_id)
            # non blogger cannot delete posts
            response = self.client.post("/blog/delete/1/")
            self.assertEqual(response.status_code, 302)  # will be redirected
            self.logout()

            self.login(user_id, blogger=True)
            response = self.client.post("/blog/delete/1/",
                                        follow_redirects=True)
            assert "Your post was successfully deleted" in str(response.data)

            # a user cannot delete another person's post
            self.assertEquals(current_user.get_id(), user_id)
            response = self.client.post("/blog/delete/11/",
                                        follow_redirects=True)
            assert "You do not have the rights to delete this post" in \
                   str(response.data)

            test_permission = Permission(RoleNeed("testblogger"))
            blogger_permission = Permission(RoleNeed("blogger"))
            self.assertFalse(
                test_permission.issubset(self.engine.blogger_permission))
            self.assertTrue(
                blogger_permission.issubset(self.engine.blogger_permission))
class TestViews(FlaskBloggingTestCase):

    def _create_storage(self):
        temp_dir = tempfile.gettempdir()
        self._dbfile = os.path.join(temp_dir, "temp.db")
        conn_string = 'sqlite:///'+self._dbfile
        engine = create_engine(conn_string)
        meta = MetaData()
        self.storage = SQLAStorage(engine, metadata=meta)
        meta.create_all(bind=engine)

    def _create_blogging_engine(self):
        return BloggingEngine(self.app, self.storage)

    def setUp(self):
        FlaskBloggingTestCase.setUp(self)
        self._create_storage()
        self.app.config["BLOGGING_URL_PREFIX"] = "/blog"
        self.app.config["BLOGGING_PLUGINS"] = []
        self.engine = self._create_blogging_engine()
        self.login_manager = LoginManager(self.app)

        @self.login_manager.user_loader
        @self.engine.user_loader
        def load_user(user_id):
            return TestUser(user_id)

        @self.app.route("/login/<username>/", methods=["POST"],
                        defaults={"blogger": 0})
        @self.app.route("/login/<username>/<int:blogger>/", methods=["POST"])
        def login(username, blogger):
            this_user = TestUser(username)
            login_user(this_user)
            if blogger:
                identity_changed.send(current_app._get_current_object(),
                                      identity=Identity(username))
            return redirect("/")

        @self.app.route("/logout/")
        def logout():
            logout_user()
            identity_changed.send(current_app._get_current_object(),
                                  identity=AnonymousIdentity())
            return redirect("/")

        for i in range(20):
            tags = ["hello"] if i < 10 else ["world"]
            user = "******" if i < 10 else "newuser"
            self.storage.save_post(title="Sample Title%d" % i,
                                   text="Sample Text%d" % i,
                                   user_id=user, tags=tags)

    def tearDown(self):
        os.remove(self._dbfile)

    def test_index(self):
        response = self.client.get("/blog/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog")
        self.assertEqual(response.status_code, 301)

        response = self.client.get("/blog/5/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/5/2/")
        self.assertEqual(response.status_code, 200)

    def test_post_by_id(self):
        response = self.client.get("/blog/page/1/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/page/1/sample-title/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/page/1")  # trailing slash redirect
        self.assertEqual(response.status_code, 301)

    def test_post_by_tag(self):
        response = self.client.get("/blog/tag/hello/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/tag/hello/5/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/tag/hello/5/2/")
        self.assertEqual(response.status_code, 200)

    def test_post_by_author(self):
        response = self.client.get("/blog/author/newuser/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/author/newuser/5/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/author/newuser/5/2/")
        self.assertEqual(response.status_code, 200)

        response = self.client.get("/blog/author/nonexistent_user/",
                                   follow_redirects=True)
        assert "No posts found for this user!" in str(response.data)

    def test_editor_get(self):
        user_id = "testuser"
        with self.client:
            # access to editor should be forbidden before login
            response = self.client.get("/blog/editor/")
            self.assertEqual(response.status_code, 401)

            response = self.client.get("/blog/editor/1/")
            self.assertEqual(response.status_code, 401)

            self.login(user_id)
            self.assertEquals(current_user.get_id(), user_id)
            response = self.client.get("/blog/editor/")
            assert response.status_code == 200

            for i in range(1, 21):
                # logged in user can edit their post, and will be redirected
                # if they try to edit other's post
                response = self.client.get("/blog/editor/%d/" % i)
                expected_status_code = 200 if i <= 10 else 302
                self.assertEqual(response.status_code, expected_status_code,
                                 "Error for item %d %d" %
                                 (i, response.status_code))
            # logout and the access should be gone again
            self.logout()
            response = self.client.get("/blog/editor/")
            self.assertEqual(response.status_code, 401)

            response = self.client.get("/blog/editor/1/")
            self.assertEqual(response.status_code, 401)

    def test_editor_post(self):
        user_id = "testuser"
        with self.client:
            # access to editor should be forbidden before login
            response = self.client.get("/blog/page/21/",
                                       follow_redirects=True)
            assert "The page you are trying to access is not valid!" in \
                   str(response.data)

            response = self.client.post("/blog/editor/")
            self.assertEqual(response.status_code, 401)

            response = self.client.post("/blog/editor/1/")
            self.assertEqual(response.status_code, 401)

            self.login(user_id)
            self.assertEquals(current_user.get_id(), user_id)

            response = self.client.post("/blog/editor/",
                                        data=dict(text="Test Text",
                                                  tags="tag1, tag2"))
            # should give back the editor page
            self.assertEqual(response.status_code, 200)

            response = self.client.post("/blog/editor/",
                                        data=dict(title="Test Title",
                                                  text="Test Text",
                                                  tags="tag1, tag2"))
            self.assertEqual(response.status_code, 302)

            response = self.client.get("/blog/page/21/")
            self.assertEqual(response.status_code, 200)

    def test_editor_edit_page(self):
        user_id = "testuser"
        with self.client:
            self.login(user_id)
            response = self.client.post(
                "/blog/editor/1/",
                data=dict(title="Sample Title0-Edited",
                          text="Sample Text0-Edited", tags="tag1, tag2"))

            response = self.client.get("/blog/100/")
            self.assertEqual(response.status_code, 200)

            pattern = re.compile(b"<h1>.*</h1>")
            headings = pattern.findall(response.data)
            self.assertEqual(len(headings), 20)
            self.assertEqual(headings[-1], b"<h1>Sample Title0-Edited</h1>")

            return

    def test_delete(self):
        user_id = "testuser"
        with self.client:

            # Anonymous user cannot delete
            response = self.client.post("/blog/delete/1/")
            self.assertEqual(response.status_code, 401)

            # a user cannot delete another person's post
            self.login(user_id)
            self.assertEquals(current_user.get_id(), user_id)
            response = self.client.post("/blog/delete/11/",
                                        follow_redirects=True)
            assert "You do not have the rights to delete this post" in \
                   str(response.data)

            # a user can delete his posts
            response = self.client.post("/blog/delete/1/",
                                        follow_redirects=True)
            assert "Your post was successfully deleted" in str(response.data)

    def login(self, user_id, blogger=False):
        if blogger:
            return self.client.post("/login/%s/1/" % user_id,
                                    follow_redirects=True)
        else:
            return self.client.post("/login/%s/" % user_id,
                                    follow_redirects=True)

    def logout(self):
        return self.client.get("/logout/")

    def test_sitemap(self):
        with self.client:
            # access to editor should be forbidden before login
            response = self.client.get("/blog/sitemap.xml")
            self.assertEqual(response.status_code, 200)

    def test_atom(self):
        with self.client:
            # access to editor should be forbidden before login
            response = self.client.get("/blog/feeds/all.atom.xml")
            self.assertEqual(response.status_code, 200)

    def test_posts_per_page(self):
        posts_per_page = 5
        self.app.config["BLOGGING_POSTS_PER_PAGE"] = posts_per_page
        with self.client:
            pattern = re.compile(b"<h1>.*</h1>")
            # index page
            response = self.client.get("/blog/")
            headings = pattern.findall(response.data)
            self.assertEqual(len(headings), posts_per_page)

            # tag page
            response = self.client.get("/blog/tag/hello/")
            headings = pattern.findall(response.data)
            self.assertEqual(len(headings), posts_per_page)

            # author page
            response = self.client.get("/blog/author/testuser/")
            headings = pattern.findall(response.data)
            self.assertEqual(len(headings), posts_per_page)

    def test_url_construction(self):
        ctx = self.app.test_request_context()
        ctx.push()
        index_url = url_for("blogging.index")
        self.assertEqual(index_url, "/blog/")

        # index url
        index_url = url_for("blogging.index", count=10)
        self.assertEqual(index_url, "/blog/10/")

        index_url = url_for("blogging.index", count=10, page=2)
        self.assertEqual(index_url, "/blog/10/2/")

        # page by id
        page_url = url_for("blogging.page_by_id", post_id=5)
        self.assertEqual(page_url, "/blog/page/5/")

        # posts by tag
        tag_url = url_for("blogging.posts_by_tag", tag="hello")
        self.assertEqual(tag_url, "/blog/tag/hello/")

        # posts by author
        author_url = url_for("blogging.posts_by_author", user_id="newuser")
        self.assertEqual(author_url, "/blog/author/newuser/")

        # sitemap
        sitemap_url = url_for("blogging.sitemap")
        self.assertEqual(sitemap_url, "/blog/sitemap.xml")

        # feeds
        feed_url = url_for("blogging.feed")
        self.assertEqual(feed_url, "/blog/feeds/all.atom.xml")

        ctx.pop()

    def _set_identity_loader(self):
        @identity_loaded.connect_via(self.app)
        def on_identity_loaded(sender, identity):
            identity.user = current_user
            if hasattr(current_user, "id"):
                identity.provides.add(UserNeed(current_user.id))
            identity.provides.add(RoleNeed("blogger"))

    def test_permissions_editor(self):
        self.app.config["BLOGGING_PERMISSIONS"] = True
        user_id = "newuser"
        self._set_identity_loader()

        with self.client:
            response = self.client.post("/blog/editor/")
            self.assertEqual(response.status_code, 401)

            response = self.client.post("/blog/editor/1/")
            self.assertEqual(response.status_code, 401)

            self.login(user_id)
            response = self.client.post("/blog/editor/")
            self.assertEqual(response.status_code, 302)

            response = self.client.post("/blog/editor/1/")
            self.assertEqual(response.status_code, 302)

            self.logout()

            self.login(user_id, blogger=True)
            response = self.client.post("/blog/editor/")
            self.assertEqual(response.status_code, 200)

            response = self.client.post("/blog/editor/1/")
            self.assertEqual(response.status_code, 200)

    def test_permissions_delete(self):
        self.app.config["BLOGGING_PERMISSIONS"] = True
        user_id = "testuser"
        self._set_identity_loader()

        with self.client:
            # Anonymous user cannot delete
            response = self.client.post("/blog/delete/1/")
            self.assertEqual(response.status_code, 401)

            self.login(user_id)
            # non blogger cannot delete posts
            response = self.client.post("/blog/delete/1/")
            self.assertEqual(response.status_code, 302)  # will be redirected
            self.logout()

            self.login(user_id, blogger=True)
            response = self.client.post("/blog/delete/1/",
                                        follow_redirects=True)
            assert "Your post was successfully deleted" in str(response.data)

            # a user cannot delete another person's post
            self.assertEquals(current_user.get_id(), user_id)
            response = self.client.post("/blog/delete/11/",
                                        follow_redirects=True)
            assert "You do not have the rights to delete this post" in \
                   str(response.data)
 def _create_storage(self):
     self._engine = create_engine(
         "mysql+mysqldb://root:@localhost/flask_blogging")
     self._meta = sqla.MetaData()
     self.storage = SQLAStorage(self._engine, metadata=self._meta)
     self._meta.create_all(bind=self._engine)
class TestSQLiteStorage(FlaskBloggingTestCase):
    def _create_storage(self):
        temp_dir = tempfile.gettempdir()
        self._dbfile = os.path.join(temp_dir, "temp.db")
        self._engine = create_engine('sqlite:///' + self._dbfile)
        self._meta = sqla.MetaData()
        self.storage = SQLAStorage(self._engine, metadata=self._meta)
        self._meta.create_all(bind=self._engine)

    def setUp(self):
        FlaskBloggingTestCase.setUp(self)
        self._create_storage()

    def tearDown(self):
        os.remove(self._dbfile)

    def test_post_table_exists(self):
        table_name = "post"
        with self._engine.begin() as conn:
            self.assertTrue(conn.dialect.has_table(conn, table_name))
            metadata = self._meta
            table = metadata.tables[table_name]
            columns = [t.name for t in table.columns]
            expected_columns = [
                'id', 'title', 'text', 'post_date', 'last_modified_date',
                'draft'
            ]
            self.assertListEqual(columns, expected_columns)

    def test_tag_table_exists(self):
        table_name = "tag"
        with self._engine.begin() as conn:
            self.assertTrue(conn.dialect.has_table(conn, table_name))
            metadata = self._meta
            table = metadata.tables[table_name]
            columns = [t.name for t in table.columns]
            expected_columns = ['id', 'text']
            self.assertListEqual(columns, expected_columns)

    def test_tag_post_table_exists(self):
        table_name = "tag_posts"
        with self._engine.begin() as conn:
            self.assertTrue(conn.dialect.has_table(conn, table_name))
            metadata = self._meta
            table = metadata.tables[table_name]
            columns = [t.name for t in table.columns]
            expected_columns = ['tag_id', 'post_id']
            self.assertListEqual(columns, expected_columns)

    def test_user_post_table_exists(self):
        table_name = "user_posts"
        with self._engine.begin() as conn:
            self.assertTrue(conn.dialect.has_table(conn, table_name))
            metadata = self._meta
            table = metadata.tables[table_name]
            columns = [t.name for t in table.columns]
            expected_columns = ['user_id', 'post_id']
            self.assertListEqual(columns, expected_columns)

    def test_user_post_table_consistency(self):
        # check if the user post table updates the user_id
        user_id = 1
        post_id = 5
        pid = self.storage.save_post(title="Title",
                                     text="Sample Text",
                                     user_id="user",
                                     tags=["hello", "world"])
        posts = self.storage.get_posts()
        self.assertEqual(len(posts), 1)
        self.storage.save_post(title="Title",
                               text="Sample Text",
                               user_id="newuser",
                               tags=["hello", "world"],
                               post_id=pid)
        self.assertEqual(len(posts), 1)
        return

    def test_tags_uniqueness(self):
        table_name = "tag"
        metadata = self._meta
        table = metadata.tables[table_name]
        with self._engine.begin() as conn:
            statement = table.insert().values(text="test_tag")
            conn.execute(statement)
        # reentering same tag should raise exception
        with self._engine.begin() as conn:
            statement = table.insert().values(text="test_tag")
            self.assertRaises(sqla.exc.IntegrityError, conn.execute, statement)

    def test_tags_consistency(self):
        # check that when tag is updated, the posts get updated
        tags = ["hello", "world"]
        pid = self.storage.save_post(title="Title",
                                     text="Sample Text",
                                     user_id="user",
                                     tags=tags)
        post = self.storage.get_post_by_id(pid)
        self.assertEqual(len(post["tags"]), 2)
        tags.pop()
        pid = self.storage.save_post(title="Title",
                                     text="Sample Text",
                                     user_id="user",
                                     tags=tags,
                                     post_id=pid)
        post = self.storage.get_post_by_id(pid)
        self.assertEqual(len(post["tags"]), 1)

    def test_tag_post_uniqueness(self):
        self.storage.save_post(title="Title",
                               text="Sample Text",
                               user_id="user",
                               tags=["tags"])
        table_name = "tag_posts"
        metadata = self._meta
        table = metadata.tables[table_name]
        with self._engine.begin() as conn:
            statement = table.insert().values(tag_id=1, post_id=1)
            self.assertRaises(sqla.exc.IntegrityError, conn.execute, statement)

    def test_user_post_uniqueness(self):
        pid = self.storage.save_post(title="Title1",
                                     text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"])
        table_name = "user_posts"
        metadata = sqla.MetaData()
        metadata.reflect(bind=self._engine)
        table = metadata.tables[table_name]
        # reentering same user should raise exception
        with self._engine.begin() as conn:
            statement = table.insert().values(user_id="testuser", post_id=pid)
            self.assertRaises(sqla.exc.IntegrityError, conn.execute, statement)

    def test_bind_database(self):
        # self.storage._create_all_tables()
        self.test_post_table_exists()
        self.test_tag_table_exists()
        self.test_tag_post_table_exists()
        self.test_user_post_table_exists()

    def test_save_post(self):
        pid = self.storage.save_post(title="Title1",
                                     text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"])
        pid = self.storage.save_post(title="Title1",
                                     text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"],
                                     post_id=1)
        p = self.storage.get_post_by_id(2)
        self.assertIsNone(p)

        # invalid post_id will be treated as inserts
        pid = self.storage.save_post(title="Title1",
                                     text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"],
                                     post_id=5)
        self.assertNotEqual(pid, 5)
        self.assertEqual(pid, 2)
        p = self.storage.get_post_by_id(2)
        self.assertIsNotNone(p)

    def test_delete_post(self):
        # insert, check exists, delete, check doesn't exist anymore
        pid = self.storage.save_post(title="Title1",
                                     text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"])
        p = self.storage.get_post_by_id(pid)
        self.assertIsNotNone(p)
        self.storage.delete_post(pid)
        p = self.storage.get_post_by_id(pid)
        self.assertIsNone(p)

        # insert again.
        pid = self.storage.save_post(title="Title1",
                                     text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"],
                                     post_id=1)
        p = self.storage.get_post_by_id(pid)
        self.assertIsNotNone(p)

    def test_get_post_by_id(self):
        pid1 = self.storage.save_post(title="Title1",
                                      text="Sample Text1",
                                      user_id="testuser",
                                      tags=["hello", "world"])
        pid2 = self.storage.save_post(title="Title2",
                                      text="Sample Text2",
                                      user_id="testuser",
                                      tags=["hello", "my", "world"])

        post = self.storage.get_post_by_id(pid1)
        self._assert_post(post, "Title1", "Sample Text1", "testuser",
                          ["HELLO", "WORLD"])

        post = self.storage.get_post_by_id(pid2)
        self._assert_post(post, "Title2", "Sample Text2", "testuser",
                          ["HELLO", "MY", "WORLD"])

    def _assert_post(self, post, title, text, user_id, tags):
        tags = set([t.upper() for t in tags])
        self.assertSetEqual(set(post["tags"]), tags)
        self.assertEqual(post["title"], title)
        self.assertEqual(post["text"], text)
        self.assertEqual(post["user_id"], user_id)

    def test_get_posts(self):
        self._create_dummy_data()

        # test default queries
        posts = self.storage.get_posts()
        self.assertEqual(len(posts), 10)
        ctr = 19
        for post in posts:
            self._assert_post(post, "Title%d" % ctr, "Sample Text%d" % ctr,
                              "newuser", ["world"])
            ctr -= 1

        posts = self.storage.get_posts(recent=False)
        self.assertEqual(len(posts), 10)
        ctr = 0
        for post in posts:
            self._assert_post(post, "Title%d" % ctr, "Sample Text%d" % ctr,
                              "testuser", ["hello"])
            ctr += 1

        # test count and offset
        posts = self.storage.get_posts(count=5, offset=5, recent=False)
        self.assertEqual(len(posts), 5)
        ctr = 5
        for post in posts:
            self._assert_post(post, "Title%d" % ctr, "Sample Text%d" % ctr,
                              "testuser", ["hello"])
            ctr += 1

        # test tag feature
        posts = self.storage.get_posts(tag="hello", recent=False)
        self.assertEqual(len(posts), 10)
        ctr = 0
        for post in posts:
            self._assert_post(post, "Title%d" % ctr, "Sample Text%d" % ctr,
                              "testuser", ["hello"])
            ctr += 1
        posts = self.storage.get_posts(tag="world", recent=False)
        self.assertEqual(len(posts), 10)
        ctr = 10
        for post in posts:
            self._assert_post(post, "Title%d" % ctr, "Sample Text%d" % ctr,
                              "newuser", ["world"])
            ctr += 1

        # test user_id feature
        posts = self.storage.get_posts(user_id="newuser", recent=True)
        self.assertEqual(len(posts), 10)
        ctr = 19
        for post in posts:
            self._assert_post(post, "Title%d" % ctr, "Sample Text%d" % ctr,
                              "newuser", ["world"])
            ctr -= 1

        posts = self.storage.get_posts(user_id="testuser", recent=True)
        self.assertEqual(len(posts), 10)
        ctr = 9
        for post in posts:
            self._assert_post(post, "Title%d" % ctr, "Sample Text%d" % ctr,
                              "testuser", ["hello"])
            ctr -= 1
        return

    def test_count_posts(self):
        self._create_dummy_data()

        count = self.storage.count_posts()
        self.assertEqual(count, 20)

        # test user
        count = self.storage.count_posts(user_id="testuser")
        self.assertEqual(count, 10)

        count = self.storage.count_posts(user_id="newuser")
        self.assertEqual(count, 10)

        count = self.storage.count_posts(user_id="testuser")
        self.assertEqual(count, 10)

        # test tags
        count = self.storage.count_posts(tag="hello")
        self.assertEqual(count, 10)

        count = self.storage.count_posts(tag="world")
        self.assertEqual(count, 10)

        # multiple queries
        count = self.storage.count_posts(user_id="testuser", tag="world")
        self.assertEqual(count, 0)

    def _create_dummy_data(self):
        for i in range(20):
            tags = ["hello"] if i < 10 else ["world"]
            user = "******" if i < 10 else "newuser"
            self.storage.save_post(title="Title%d" % i,
                                   text="Sample Text%d" % i,
                                   user_id=user,
                                   tags=tags)
            time.sleep(1)
class TestSQLiteStorage(FlaskBloggingTestCase):

    def _create_storage(self):
        temp_dir = tempfile.gettempdir()
        self._dbfile = os.path.join(temp_dir, "temp.db")
        self._engine = create_engine('sqlite:///'+self._dbfile)
        self._meta = sqla.MetaData()
        self.storage = SQLAStorage(self._engine, metadata=self._meta)
        self._meta.create_all(bind=self._engine)

    def setUp(self):
        FlaskBloggingTestCase.setUp(self)
        self._create_storage()

    def tearDown(self):
        os.remove(self._dbfile)

    def test_post_table_exists(self):
        table_name = "post"
        with self._engine.begin() as conn:
            self.assertTrue(conn.dialect.has_table(conn, table_name))
            metadata = self._meta
            table = metadata.tables[table_name]
            columns = [t.name for t in table.columns]
            expected_columns = ['id', 'title', 'text', 'post_date',
                                'last_modified_date', 'draft']
            self.assertListEqual(columns, expected_columns)

    def test_tag_table_exists(self):
        table_name = "tag"
        with self._engine.begin() as conn:
            self.assertTrue(conn.dialect.has_table(conn, table_name))
            metadata = self._meta
            table = metadata.tables[table_name]
            columns = [t.name for t in table.columns]
            expected_columns = ['id', 'text']
            self.assertListEqual(columns, expected_columns)

    def test_tag_post_table_exists(self):
        table_name = "tag_posts"
        with self._engine.begin() as conn:
            self.assertTrue(conn.dialect.has_table(conn, table_name))
            metadata = self._meta
            table = metadata.tables[table_name]
            columns = [t.name for t in table.columns]
            expected_columns = ['tag_id', 'post_id']
            self.assertListEqual(columns, expected_columns)

    def test_user_post_table_exists(self):
        table_name = "user_posts"
        with self._engine.begin() as conn:
            self.assertTrue(conn.dialect.has_table(conn, table_name))
            metadata = self._meta
            table = metadata.tables[table_name]
            columns = [t.name for t in table.columns]
            expected_columns = ['user_id', 'post_id']
            self.assertListEqual(columns, expected_columns)

    def test_user_post_table_consistency(self):
        # check if the user post table updates the user_id
        user_id = 1
        post_id = 5
        pid = self.storage.save_post(title="Title", text="Sample Text",
                                     user_id="user", tags=["hello", "world"])
        posts = self.storage.get_posts()
        self.assertEqual(len(posts), 1)
        self.storage.save_post(title="Title", text="Sample Text",
                               user_id="newuser", tags=["hello", "world"],
                               post_id=pid)
        self.assertEqual(len(posts), 1)
        return

    def test_tags_uniqueness(self):
        table_name = "tag"
        metadata = self._meta
        table = metadata.tables[table_name]
        with self._engine.begin() as conn:
            statement = table.insert().values(text="test_tag")
            conn.execute(statement)
        # reentering same tag should raise exception
        with self._engine.begin() as conn:
            statement = table.insert().values(text="test_tag")
            self.assertRaises(sqla.exc.IntegrityError, conn.execute, statement)

    def test_tags_consistency(self):
        # check that when tag is updated, the posts get updated
        tags = ["hello", "world"]
        pid = self.storage.save_post(title="Title", text="Sample Text",
                                     user_id="user", tags=tags)
        post = self.storage.get_post_by_id(pid)
        self.assertEqual(len(post["tags"]), 2)
        tags.pop()
        pid = self.storage.save_post(title="Title", text="Sample Text",
                                     user_id="user", tags=tags, post_id=pid)
        post = self.storage.get_post_by_id(pid)
        self.assertEqual(len(post["tags"]), 1)

    def test_tag_post_uniqueness(self):
        self.storage.save_post(title="Title", text="Sample Text",
                               user_id="user", tags=["tags"])
        table_name = "tag_posts"
        metadata = self._meta
        table = metadata.tables[table_name]
        with self._engine.begin() as conn:
            statement = table.insert().values(tag_id=1, post_id=1)
            self.assertRaises(sqla.exc.IntegrityError, conn.execute, statement)

    def test_user_post_uniqueness(self):
        pid = self.storage.save_post(title="Title1", text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"])
        table_name = "user_posts"
        metadata = sqla.MetaData()
        metadata.reflect(bind=self._engine)
        table = metadata.tables[table_name]
        # reentering same user should raise exception
        with self._engine.begin() as conn:
            statement = table.insert().values(user_id="testuser",
                                              post_id=pid)
            self.assertRaises(sqla.exc.IntegrityError, conn.execute, statement)

    def test_bind_database(self):
        # self.storage._create_all_tables()
        self.test_post_table_exists()
        self.test_tag_table_exists()
        self.test_tag_post_table_exists()
        self.test_user_post_table_exists()

    def test_save_post(self):
        pid = self.storage.save_post(title="Title1", text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"])
        pid = self.storage.save_post(title="Title1", text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"], post_id=1)
        p = self.storage.get_post_by_id(2)
        self.assertIsNone(p)

        # invalid post_id will be treated as inserts
        pid = self.storage.save_post(title="Title1", text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"],
                                     post_id=5)
        self.assertNotEqual(pid, 5)
        self.assertEqual(pid, 2)
        p = self.storage.get_post_by_id(2)
        self.assertIsNotNone(p)

    def test_delete_post(self):
        # insert, check exists, delete, check doesn't exist anymore
        pid = self.storage.save_post(title="Title1", text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"])
        p = self.storage.get_post_by_id(pid)
        self.assertIsNotNone(p)
        self.storage.delete_post(pid)
        p = self.storage.get_post_by_id(pid)
        self.assertIsNone(p)

        # insert again.
        pid = self.storage.save_post(title="Title1", text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"],
                                     post_id=1)
        p = self.storage.get_post_by_id(pid)
        self.assertIsNotNone(p)

    def test_get_post_by_id(self):
        pid1 = self.storage.save_post(title="Title1", text="Sample Text1",
                                      user_id="testuser",
                                      tags=["hello", "world"])
        pid2 = self.storage.save_post(title="Title2", text="Sample Text2",
                                      user_id="testuser",
                                      tags=["hello", "my", "world"])

        post = self.storage.get_post_by_id(pid1)
        self._assert_post(post, "Title1", "Sample Text1", "testuser",
                          ["HELLO", "WORLD"])

        post = self.storage.get_post_by_id(pid2)
        self._assert_post(post, "Title2", "Sample Text2", "testuser",
                          ["HELLO", "MY", "WORLD"])

    def _assert_post(self, post, title, text, user_id, tags):
        tags = set([t.upper() for t in tags])
        self.assertSetEqual(set(post["tags"]), tags)
        self.assertEqual(post["title"], title)
        self.assertEqual(post["text"], text)
        self.assertEqual(post["user_id"], user_id)

    def test_get_posts(self):
        self._create_dummy_data()

        # test default queries
        posts = self.storage.get_posts()
        self.assertEqual(len(posts), 10)
        ctr = 19
        for post in posts:
            self._assert_post(post, "Title%d" % ctr,
                              "Sample Text%d" % ctr, "newuser", ["world"])
            ctr -= 1

        posts = self.storage.get_posts(recent=False)
        self.assertEqual(len(posts), 10)
        ctr = 0
        for post in posts:
            self._assert_post(post, "Title%d" % ctr,
                              "Sample Text%d" % ctr, "testuser", ["hello"])
            ctr += 1

        # test count and offset
        posts = self.storage.get_posts(count=5, offset=5, recent=False)
        self.assertEqual(len(posts), 5)
        ctr = 5
        for post in posts:
            self._assert_post(post, "Title%d" % ctr,
                              "Sample Text%d" % ctr, "testuser", ["hello"])
            ctr += 1

        # test tag feature
        posts = self.storage.get_posts(tag="hello", recent=False)
        self.assertEqual(len(posts), 10)
        ctr = 0
        for post in posts:
            self._assert_post(post, "Title%d" % ctr,
                              "Sample Text%d" % ctr, "testuser", ["hello"])
            ctr += 1
        posts = self.storage.get_posts(tag="world", recent=False)
        self.assertEqual(len(posts), 10)
        ctr = 10
        for post in posts:
            self._assert_post(post, "Title%d" % ctr,
                              "Sample Text%d" % ctr, "newuser", ["world"])
            ctr += 1

        # test user_id feature
        posts = self.storage.get_posts(user_id="newuser", recent=True)
        self.assertEqual(len(posts), 10)
        ctr = 19
        for post in posts:
            self._assert_post(post, "Title%d" % ctr,
                              "Sample Text%d" % ctr, "newuser", ["world"])
            ctr -= 1

        posts = self.storage.get_posts(user_id="testuser", recent=True)
        self.assertEqual(len(posts), 10)
        ctr = 9
        for post in posts:
            self._assert_post(post, "Title%d" % ctr,
                              "Sample Text%d" % ctr, "testuser", ["hello"])
            ctr -= 1
        return

    def test_count_posts(self):
        self._create_dummy_data()

        count = self.storage.count_posts()
        self.assertEqual(count, 20)

        # test user
        count = self.storage.count_posts(user_id="testuser")
        self.assertEqual(count, 10)

        count = self.storage.count_posts(user_id="newuser")
        self.assertEqual(count, 10)

        count = self.storage.count_posts(user_id="testuser")
        self.assertEqual(count, 10)

        # test tags
        count = self.storage.count_posts(tag="hello")
        self.assertEqual(count, 10)

        count = self.storage.count_posts(tag="world")
        self.assertEqual(count, 10)

        # multiple queries
        count = self.storage.count_posts(user_id="testuser", tag="world")
        self.assertEqual(count, 0)

    def _create_dummy_data(self):
        for i in range(20):
            tags = ["hello"] if i < 10 else ["world"]
            user = "******" if i < 10 else "newuser"
            self.storage.save_post(title="Title%d" % i,
                                   text="Sample Text%d" % i,
                                   user_id=user, tags=tags)
            time.sleep(1)