예제 #1
0
파일: cli_test.py 프로젝트: Greunlis/gramps
class UnicodeTest(unittest.TestCase):

    def setUp(self):
        from gramps.cli.clidbman import CLIDbManager
        from gramps.gen.config import set as setconfig, get as getconfig
        from gramps.gen.dbstate import DbState
        self.newpath = os.path.join(os.path.dirname(__file__),
                                    '\u0393\u03c1\u03b1\u03bc\u03c3\u03c0')
        self.newtitle = 'Gr\u00e4mps T\u00e9st'
        os.makedirs(self.newpath)
        self.old_path = getconfig('behavior.database-path')
        setconfig('behavior.database-path', self.newpath)
        self.cli = CLIDbManager(DbState())

    def tearDown(self):
        from gramps.gen.config import set as setconfig
        for (dirpath, dirnames, filenames) in os.walk(self.newpath, False):
            for afile in filenames:
                os.remove(os.path.join(dirpath, afile))
            for adir in dirnames:
                os.rmdir(os.path.join(dirpath, adir))
        os.rmdir(self.newpath)
        setconfig('behavior.database-path', self.old_path)

    # Test that clidbman will open files in a path containing
    # arbitrary Unicode characters.
    def test4_arbitrary_uncode_path(self):
        (dbpath, title) = self.cli.create_new_db_cli(self.newtitle)

        self.assertEqual(self.newpath, os.path.dirname(dbpath),
                          "Compare paths %s and %s" % (repr(self.newpath),
                                                       repr(dbpath)))
        self.assertEqual(self.newtitle, title, "Compare titles %s and %s" %
                          (repr(self.newtitle), repr(title)))
예제 #2
0
class UnicodeTest(unittest.TestCase):

    def setUp(self):
        from gramps.cli.clidbman import CLIDbManager
        from gramps.gen.config import set as setconfig, get as getconfig
        from gramps.gen.dbstate import DbState
        self.newpath = os.path.join(os.path.dirname(__file__),
                                    '\u0393\u03c1\u03b1\u03bc\u03c3\u03c0')
        self.newtitle = 'Gr\u00e4mps T\u00e9st'
        os.makedirs(self.newpath)
        self.old_path = getconfig('database.path')
        setconfig('database.path', self.newpath)
        self.cli = CLIDbManager(DbState())

    def tearDown(self):
        from gramps.gen.config import set as setconfig
        for (dirpath, dirnames, filenames) in os.walk(self.newpath, False):
            for afile in filenames:
                os.remove(os.path.join(dirpath, afile))
            for adir in dirnames:
                os.rmdir(os.path.join(dirpath, adir))
        os.rmdir(self.newpath)
        setconfig('database.path', self.old_path)

    # Test that clidbman will open files in a path containing
    # arbitrary Unicode characters.
    def test4_arbitrary_uncode_path(self):
        (dbpath, title) = self.cli.create_new_db_cli(self.newtitle)

        self.assertEqual(self.newpath, os.path.dirname(dbpath),
                          "Compare paths %s and %s" % (repr(self.newpath),
                                                       repr(dbpath)))
        self.assertEqual(self.newtitle, title, "Compare titles %s and %s" %
                          (repr(self.newtitle), repr(title)))
예제 #3
0
 def setUp(self):
     """Create a fresh database titled "Test" and load it into ``self._db``.

     Also makes a temporary directory, stored in ``self._tmpdir``; this
     method itself does not use it.
     """
     self._tmpdir = tempfile.mkdtemp()

     self._db = DbBsddb()
     manager = CLIDbManager(None)
     self._filename, _title = manager.create_new_db_cli(title="Test")
     # load() is handed a callback that ignores its single argument;
     # the database is opened with mode "w".
     self._db.load(self._filename, lambda _ignored: None, "w")
예제 #4
0
class TestUserCreateOwner(unittest.TestCase):
    """Test cases for the /api/user/create_owner endpoint."""
    def setUp(self):
        self.name = "Test Web API"
        self.dbman = CLIDbManager(DbState())
        _, _name = self.dbman.create_new_db_cli(self.name, dbid="sqlite")
        with patch.dict("os.environ", {ENV_CONFIG_FILE: TEST_AUTH_CONFIG}):
            self.app = create_app()
        self.app.config["TESTING"] = True
        self.client = self.app.test_client()
        sqlauth = self.app.config["AUTH_PROVIDER"]
        sqlauth.create_table()
        self.ctx = self.app.test_request_context()
        self.ctx.push()

    def tearDown(self):
        self.ctx.pop()
        self.dbman.remove_database(self.name)

    def test_create_owner(self):
        rv = self.client.get(f"{BASE_URL}/token/create_owner/")
        assert rv.status_code == 200
        token = rv.json["access_token"]
        assert self.app.config["AUTH_PROVIDER"].get_number_users() == 0
        # data missing
        rv = self.client.post(
            f"{BASE_URL}/users/tree_owner/create_owner/",
            headers={"Authorization": "Bearer {}".format(token)},
            json={"full_name": "My Name"},
        )
        assert rv.status_code == 422
        rv = self.client.post(
            f"{BASE_URL}/users/tree_owner/create_owner/",
            headers={"Authorization": "Bearer {}".format(token)},
            json={
                "password": "******",
                "email": "*****@*****.**",
                "full_name": "My Name",
            },
        )
        assert rv.status_code == 201
        assert self.app.config["AUTH_PROVIDER"].get_number_users() == 1
        # try posting again
        rv = self.client.post(
            f"{BASE_URL}/users/tree_owner_2/create_owner/",
            headers={"Authorization": "Bearer {}".format(token)},
            json={
                "password": "******",
                "email": "*****@*****.**",
                "full_name": "My Name",
            },
        )
        assert rv.status_code == 405
        assert self.app.config["AUTH_PROVIDER"].get_number_users() == 1
        rv = self.client.get(f"{BASE_URL}/token/create_owner/")
        assert rv.status_code == 405
예제 #5
0
    def setUp(self):
        """Create a fresh database titled ``cuni("Test")`` and load it as ``self._db``.

        A temporary directory is also created and kept in ``self._tmpdir``;
        this method does not use it itself.
        """
        self._tmpdir = tempfile.mkdtemp()

        self._db = DbBsddb()
        manager = CLIDbManager(None)
        created = manager.create_new_db_cli(title=cuni("Test"))
        self._filename, _title = created
        # The callback passed to load() ignores its argument; mode is "w".
        self._db.load(self._filename, lambda _ignored: None, "w")
예제 #6
0
class GrampsDbBaseTest(unittest.TestCase):
    """Base class for unittest that need to be able to create
    test databases."""

    def setUp(self):
        """Create a bsddb tree titled "Test: bsddb" and load it as ``self._db``."""
        self.dbstate = DbState()
        self.dbman = CLIDbManager(self.dbstate)
        dirpath, _name = self.dbman.create_new_db_cli("Test: bsddb", dbid="bsddb")
        self._db = make_database("bsddb")
        try:
            self._db.load(dirpath, None)
        except Exception:
            # unittest skips tearDown() when setUp() raises; remove the tree
            # created above so it does not leak into later runs.
            self.dbman.remove_database("Test: bsddb")
            raise

    def tearDown(self):
        """Close the test database and delete it, even if closing fails."""
        try:
            self._db.close()
        finally:
            self.dbman.remove_database("Test: bsddb")

    def _populate_database(
        self, num_sources=1, num_persons=0, num_families=0, num_events=0, num_places=0, num_media=0, num_links=1
    ):
        """Fill the test database with objects that cite sources.

        ``num_sources`` sources are created first (each via ``_add_source``,
        which returns a citation).  Then, for persons, families, events,
        places and media in that order, the requested number of objects is
        created; each one cites ``num_links`` of those citations, chosen
        round-robin.  The rotation restarts at the first citation for every
        object type.  Citations are collected in a set, so an object never
        gets more than ``num_sources`` distinct citations.

        :raises ValueError: if objects must be linked (``num_links > 0``)
            but ``num_sources`` is 0.
        """
        # start with sources
        sources = [self._add_source() for _ in range(num_sources)]

        for num, add_func in (
            (num_persons, self._add_person_with_sources),
            (num_families, self._add_family_with_sources),
            (num_events, self._add_event_with_sources),
            (num_places, self._add_place_with_sources),
            (num_media, self._add_media_with_sources),
        ):
            if num > 0 and num_links > 0 and not sources:
                raise ValueError(
                    "cannot link %d object(s) to %d source(s) each: "
                    "num_sources is 0" % (num, num_links)
                )

            source_idx = 0
            for obj_idx in range(num):
                # Get the set of citations to link, round-robin.
                lnk_sources = set()
                for _ in range(num_links):
                    lnk_sources.add(sources[source_idx])
                    source_idx = (source_idx + 1) % len(sources)

                try:
                    add_func(lnk_sources)
                except Exception:
                    # Report which object failed, then let the error propagate.
                    print("person_idx = ", obj_idx)
                    print("lnk_sources = ", repr(lnk_sources))
                    raise

    def _add_source(self, repos=None):
        """Add a Source and a Citation referencing it in one transaction.

        If *repos* is given, a RepoRef to it is attached to the source.
        Returns the Citation, which is what the other helpers link to.
        """
        with DbTxn("Add Source and Citation", self._db) as tran:
            source = Source()
            if repos is not None:
                repo_ref = RepoRef()
                repo_ref.set_reference_handle(repos.get_handle())
                source.add_repo_reference(repo_ref)
            self._db.add_source(source, tran)
            self._db.commit_source(source, tran)
            citation = Citation()
            citation.set_reference_handle(source.get_handle())
            self._db.add_citation(citation, tran)
            self._db.commit_citation(citation, tran)

        return citation

    def _add_repository(self):
        """Add and commit an empty Repository; return it."""
        with DbTxn("Add Repository", self._db) as tran:
            repos = Repository()
            self._db.add_repository(repos, tran)
            self._db.commit_repository(repos, tran)

        return repos

    def _add_object_with_source(self, citations, object_class, add_method, commit_method):
        """Create ``object_class()``, cite every citation in *citations*, then
        add and commit it with the given db methods in one transaction.

        Returns the new object.
        """
        obj = object_class()

        with DbTxn("Add Object", self._db) as tran:
            for citation in citations:
                obj.add_citation(citation.get_handle())
            add_method(obj, tran)
            commit_method(obj, tran)

        return obj

    def _add_person_with_sources(self, citations):
        """Add a Person citing *citations*; return it."""
        return self._add_object_with_source(citations, Person, self._db.add_person, self._db.commit_person)

    def _add_family_with_sources(self, citations):
        """Add a Family citing *citations*; return it."""
        return self._add_object_with_source(citations, Family, self._db.add_family, self._db.commit_family)

    def _add_event_with_sources(self, citations):
        """Add an Event citing *citations*; return it."""
        return self._add_object_with_source(citations, Event, self._db.add_event, self._db.commit_event)

    def _add_place_with_sources(self, citations):
        """Add a Place citing *citations*; return it."""
        return self._add_object_with_source(citations, Place, self._db.add_place, self._db.commit_place)

    def _add_media_with_sources(self, citations):
        """Add a Media object citing *citations*; return it."""
        return self._add_object_with_source(citations, Media, self._db.add_media, self._db.commit_media)
예제 #7
0
def print_db_content(db):
    """Print one line per person handle in *db*, then one per source handle."""
    listings = (
        ("DB contains: person %s", db.get_person_handles),
        ("DB contains: source %s", db.get_source_handles),
    )
    for template, fetch_handles in listings:
        for handle in fetch_handles():
            print(template % handle)


# Smoke script: create a bsddb family tree, add three people (the second
# one via a call labelled "without commit") and two sources, print the
# stored handles and close the database.
# NOTE(review): this excerpt ends before the `except`/`finally` of the
# `try:` below; cleanup of `tmpdir` is presumably done there - confirm.
tmpdir = tempfile.mkdtemp()
try:
    filename1 = os.path.join(tmpdir, 'test1.grdb')
    # NOTE(review): filename2 is not used anywhere in the visible script.
    filename2 = os.path.join(tmpdir, 'test2.grdb')
    print("\nUsing Database file: %s" % filename1)
    dbstate = DbState()
    dbman = CLIDbManager(dbstate)
    # filename1 is passed as the first argument, which other callers pass
    # as the title (e.g. `title="Test"`); the database that is actually
    # loaded is the returned `dirpath` - verify it is what was intended.
    dirpath, name = dbman.create_new_db_cli(filename1, dbid="bsddb")
    db = dbstate.make_database("bsddb")
    db.load(dirpath, None)
    # The 4th argument of add_person/add_source appears to control whether
    # the change is committed (cf. the "without commit" message) - confirm
    # against their definitions.
    print("Add person 1")
    add_person(db, "Anton", "Albers", True, False)
    print("Add source")
    add_source(db, "A short test", True, False)
    print("Add person 2 without commit")
    add_person(db, "Bernd", "Beta", False, False)
    print("Add source")
    add_source(db, "A short test", True, False)
    print("Add person 3")
    add_person(db, "Chris", "Connor", True, False)
    print_db_content(db)
    print("Closing Database file: %s" % filename1)
    db.close()
예제 #8
0
        tran = None

def print_db_content(db):
    """List the person handles and then the source handles stored in *db*."""

    def dump(template, handles):
        # One printed line per handle, formatted with %-substitution.
        for handle in handles:
            print(template % handle)

    dump("DB contains: person %s", db.get_person_handles())
    dump("DB contains: source %s", db.get_source_handles())

# Smoke script: create a bsddb family tree, add three people (the second
# one via a call labelled "without commit") and two sources, print the
# stored handles and close the database.
# NOTE(review): this excerpt ends before the `except`/`finally` of the
# `try:` below; cleanup of `tmpdir` is presumably done there - confirm.
tmpdir = tempfile.mkdtemp()
try:
    filename1 = os.path.join(tmpdir,'test1.grdb')
    # NOTE(review): filename2 is not used anywhere in the visible script.
    filename2 = os.path.join(tmpdir,'test2.grdb')
    print("\nUsing Database file: %s" % filename1)
    dbstate = DbState()
    dbman = CLIDbManager(dbstate)
    # filename1 is passed as the first argument, which other callers pass
    # as the title (e.g. `title="Test"`); the database that is actually
    # loaded is the returned `dirpath` - verify it is what was intended.
    dirpath, name = dbman.create_new_db_cli(filename1, dbid="bsddb")
    db = make_database("bsddb")
    db.load(dirpath, None)
    # The 4th argument of add_person/add_source appears to control whether
    # the change is committed (cf. the "without commit" message) - confirm
    # against their definitions.
    print("Add person 1")
    add_person( db,"Anton", "Albers",True,False)
    print("Add source")
    add_source( db,"A short test",True,False)
    print("Add person 2 without commit")
    add_person( db,"Bernd","Beta",False,False)
    print("Add source")
    add_source( db,"A short test",True,False)
    print("Add person 3")
    add_person( db,"Chris","Connor",True,False)
    print_db_content( db)
    print("Closing Database file: %s" % filename1)
    db.close()
예제 #9
0
class GrampsDbBaseTest(unittest.TestCase):
    """Base class for unittest that need to be able to create
    test databases."""
    def setUp(self):
        """Create a bsddb tree titled "Test: bsddb" and load it as ``self._db``."""
        self.dbstate = DbState()
        self.dbman = CLIDbManager(self.dbstate)
        dirpath, _name = self.dbman.create_new_db_cli("Test: bsddb",
                                                      dbid="bsddb")
        self._db = self.dbstate.make_database("bsddb")
        try:
            self._db.load(dirpath, None)
        except Exception:
            # unittest skips tearDown() when setUp() raises; remove the tree
            # created above so it does not leak into later runs.
            self.dbman.remove_database("Test: bsddb")
            raise

    def tearDown(self):
        """Close the test database and delete it, even if closing fails."""
        try:
            self._db.close()
        finally:
            self.dbman.remove_database("Test: bsddb")

    def _populate_database(self,
                           num_sources=1,
                           num_persons=0,
                           num_families=0,
                           num_events=0,
                           num_places=0,
                           num_media=0,
                           num_links=1):
        """Fill the test database with objects that cite sources.

        ``num_sources`` sources are created first (each via ``_add_source``,
        which returns a citation).  Then, for persons, families, events,
        places and media in that order, the requested number of objects is
        created; each one cites ``num_links`` of those citations, chosen
        round-robin.  The rotation restarts at the first citation for every
        object type.  Citations are collected in a set, so an object never
        gets more than ``num_sources`` distinct citations.

        :raises ValueError: if objects must be linked (``num_links > 0``)
            but ``num_sources`` is 0.
        """
        # start with sources
        sources = [self._add_source() for _ in range(num_sources)]

        for num, add_func in ((num_persons, self._add_person_with_sources),
                              (num_families, self._add_family_with_sources),
                              (num_events, self._add_event_with_sources),
                              (num_places, self._add_place_with_sources),
                              (num_media, self._add_media_with_sources)):

            if num > 0 and num_links > 0 and not sources:
                raise ValueError(
                    "cannot link %d object(s) to %d source(s) each: "
                    "num_sources is 0" % (num, num_links))

            source_idx = 0
            for obj_idx in range(num):

                # Get the set of citations to link, round-robin.
                lnk_sources = set()
                for _ in range(num_links):
                    lnk_sources.add(sources[source_idx])
                    source_idx = (source_idx + 1) % len(sources)

                try:
                    add_func(lnk_sources)
                except Exception:
                    # Report which object failed, then let the error propagate.
                    print("person_idx = ", obj_idx)
                    print("lnk_sources = ", repr(lnk_sources))
                    raise

    def _add_source(self, repos=None):
        """Add a Source and a Citation referencing it in one transaction.

        If *repos* is given, a RepoRef to it is attached to the source.
        Returns the Citation, which is what the other helpers link to.
        """
        with DbTxn("Add Source and Citation", self._db) as tran:
            source = Source()
            if repos is not None:
                repo_ref = RepoRef()
                repo_ref.set_reference_handle(repos.get_handle())
                source.add_repo_reference(repo_ref)
            self._db.add_source(source, tran)
            self._db.commit_source(source, tran)
            citation = Citation()
            citation.set_reference_handle(source.get_handle())
            self._db.add_citation(citation, tran)
            self._db.commit_citation(citation, tran)

        return citation

    def _add_repository(self):
        """Add and commit an empty Repository; return it."""
        with DbTxn("Add Repository", self._db) as tran:
            repos = Repository()
            self._db.add_repository(repos, tran)
            self._db.commit_repository(repos, tran)

        return repos

    def _add_object_with_source(self, citations, object_class, add_method,
                                commit_method):
        """Create ``object_class()``, cite every citation in *citations*, then
        add and commit it with the given db methods in one transaction.

        Returns the new object.
        """
        obj = object_class()

        with DbTxn("Add Object", self._db) as tran:
            for citation in citations:
                obj.add_citation(citation.get_handle())
            add_method(obj, tran)
            commit_method(obj, tran)

        return obj

    def _add_person_with_sources(self, citations):
        """Add a Person citing *citations*; return it."""
        return self._add_object_with_source(citations, Person,
                                            self._db.add_person,
                                            self._db.commit_person)

    def _add_family_with_sources(self, citations):
        """Add a Family citing *citations*; return it."""
        return self._add_object_with_source(citations, Family,
                                            self._db.add_family,
                                            self._db.commit_family)

    def _add_event_with_sources(self, citations):
        """Add an Event citing *citations*; return it."""
        return self._add_object_with_source(citations, Event,
                                            self._db.add_event,
                                            self._db.commit_event)

    def _add_place_with_sources(self, citations):
        """Add a Place citing *citations*; return it."""
        return self._add_object_with_source(citations, Place,
                                            self._db.add_place,
                                            self._db.commit_place)

    def _add_media_with_sources(self, citations):
        """Add a Media object citing *citations*; return it."""
        return self._add_object_with_source(citations, Media,
                                            self._db.add_media,
                                            self._db.commit_media)
예제 #10
0
class TestUser(unittest.TestCase):
    """Test cases for the /api/user endpoints."""
    def setUp(self):
        self.name = "Test Web API"
        self.dbman = CLIDbManager(DbState())
        _, _name = self.dbman.create_new_db_cli(self.name, dbid="sqlite")
        with patch.dict("os.environ", {ENV_CONFIG_FILE: TEST_AUTH_CONFIG}):
            self.app = create_app()
        self.app.config["TESTING"] = True
        self.client = self.app.test_client()
        sqlauth = self.app.config["AUTH_PROVIDER"]
        sqlauth.create_table()
        sqlauth.add_user(name="user",
                         password="******",
                         email="*****@*****.**",
                         role=ROLE_MEMBER)
        sqlauth.add_user(name="owner",
                         password="******",
                         email="*****@*****.**",
                         role=ROLE_OWNER)
        self.assertTrue(self.app.testing)
        self.ctx = self.app.test_request_context()
        self.ctx.push()

    def tearDown(self):
        self.ctx.pop()
        self.dbman.remove_database(self.name)

    def test_change_password_wrong_method(self):
        rv = self.client.get(BASE_URL + "/users/-/password/change")
        assert rv.status_code == 404

    def test_change_password_no_token(self):
        rv = self.client.post(
            BASE_URL + "/users/-/password/change",
            json={
                "old_password": "******",
                "new_password": "******"
            },
        )
        assert rv.status_code == 401

    def test_change_password_wrong_old_pw(self):
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token = rv.json["access_token"]
        rv = self.client.post(
            BASE_URL + "/users/-/password/change",
            headers={"Authorization": "Bearer {}".format(token)},
            json={
                "old_password": "******",
                "new_password": "******"
            },
        )
        assert rv.status_code == 403

    def test_change_password(self):
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token = rv.json["access_token"]
        rv = self.client.post(
            BASE_URL + "/users/-/password/change",
            headers={"Authorization": "Bearer {}".format(token)},
            json={
                "old_password": "******",
                "new_password": "******"
            },
        )
        assert rv.status_code == 201
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 403
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200

    def test_change_other_user_password(self):
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token_user = rv.json["access_token"]
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token_owner = rv.json["access_token"]
        # user can't change owner's PW
        rv = self.client.post(
            BASE_URL + "/users/owner/password/change",
            headers={"Authorization": "Bearer {}".format(token_user)},
            json={
                "old_password": "******",
                "new_password": "******"
            },
        )
        assert rv.status_code == 403
        # owner can change user's PW
        rv = self.client.post(
            BASE_URL + "/users/user/password/change",
            headers={"Authorization": "Bearer {}".format(token_owner)},
            json={
                "old_password": "******",
                "new_password": "******"
            },
        )
        assert rv.status_code == 201
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 403
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200

    def test_change_password_twice(self):
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token = rv.json["access_token"]
        rv = self.client.post(
            BASE_URL + "/users/-/password/change",
            headers={"Authorization": "Bearer {}".format(token)},
            json={
                "old_password": "******",
                "new_password": "******"
            },
        )
        assert rv.status_code == 201
        rv = self.client.post(
            BASE_URL + "/users/-/password/change",
            headers={"Authorization": "Bearer {}".format(token)},
            json={
                "old_password": "******",
                "new_password": "******"
            },
        )
        assert rv.status_code == 403

    def test_reset_password_trigger_invalid_user(self):
        rv = self.client.post(BASE_URL +
                              "/users/doesn_exist/password/reset/trigger/")
        assert rv.status_code == 404

    def test_reset_password_trigger_status(self):
        with patch("smtplib.SMTP") as mock_smtp:
            rv = self.client.post(BASE_URL +
                                  "/users/user/password/reset/trigger/")
            assert rv.status_code == 201

    def test_reset_password(self):
        with patch("smtplib.SMTP") as mock_smtp:
            rv = self.client.post(BASE_URL +
                                  "/users/user/password/reset/trigger/")
            context = mock_smtp.return_value
            context.send_message.assert_called()
            name, args, kwargs = context.method_calls.pop(0)
            msg = args[0]
            # extract the token from the message body
            body = msg.get_body().get_payload().replace("=\n", "")
            matches = re.findall(r".*jwt=([^\s]+).*", body)
            self.assertEqual(len(matches), 1, msg=body)
            token = matches[0]
        # try without token!
        rv = self.client.post(
            BASE_URL + "/users/-/password/reset/",
            json={"new_password": "******"},
        )
        self.assertEqual(rv.status_code, 401)
        # try empty PW!
        rv = self.client.post(
            BASE_URL + "/users/-/password/reset/",
            headers={"Authorization": "Bearer {}".format(token)},
            json={"new_password": ""},
        )
        self.assertEqual(rv.status_code, 400)
        # now that should work
        rv = self.client.post(
            BASE_URL + "/users/-/password/reset/",
            headers={"Authorization": "Bearer {}".format(token)},
            json={"new_password": "******"},
        )
        self.assertEqual(rv.status_code, 201)
        # try again with the same token!
        rv = self.client.post(
            BASE_URL + "/users/-/password/reset/",
            headers={"Authorization": "Bearer {}".format(token)},
            json={"new_password": "******"},
        )
        self.assertEqual(rv.status_code, 409)
        # old password doesn't work anymore
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 403
        # new password works!
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200

    def test_show_user(self):
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token_user = rv.json["access_token"]
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token_owner = rv.json["access_token"]
        # user can view themselves
        rv = self.client.get(
            BASE_URL + "/users/-/",
            headers={"Authorization": "Bearer {}".format(token_user)},
        )
        assert rv.status_code == 200
        self.assertEqual(
            rv.json,
            {
                "name": "user",
                "email": "*****@*****.**",
                "role": ROLE_MEMBER,
                "full_name": None,
            },
        )
        # user cannot view others
        rv = self.client.get(
            BASE_URL + "/users/owner/",
            headers={"Authorization": "Bearer {}".format(token_user)},
        )
        assert rv.status_code == 403
        # owner can view others
        rv = self.client.get(
            BASE_URL + "/users/user/",
            headers={"Authorization": "Bearer {}".format(token_owner)},
        )
        assert rv.status_code == 200
        self.assertEqual(
            rv.json,
            {
                "name": "user",
                "email": "*****@*****.**",
                "role": ROLE_MEMBER,
                "full_name": None,
            },
        )

    def test_show_users(self):
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token_user = rv.json["access_token"]
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token_owner = rv.json["access_token"]
        # user cannot view users
        rv = self.client.get(
            BASE_URL + "/users/",
            headers={"Authorization": "Bearer {}".format(token_user)},
        )
        assert rv.status_code == 403
        # owner can view users
        rv = self.client.get(
            BASE_URL + "/users/",
            headers={"Authorization": "Bearer {}".format(token_owner)},
        )
        assert rv.status_code == 200
        self.assertEqual(
            set([user["name"] for user in rv.json]),
            {"user", "owner"},
        )

    def test_edit_user(self):
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token_user = rv.json["access_token"]
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token_owner = rv.json["access_token"]
        # user can edit themselves
        rv = self.client.put(
            BASE_URL + "/users/-/",
            headers={"Authorization": "Bearer {}".format(token_user)},
            json={"full_name": "My Name"},
        )
        assert rv.status_code == 201
        rv = self.client.get(
            BASE_URL + "/users/-/",
            headers={"Authorization": "Bearer {}".format(token_user)},
        )
        assert rv.status_code == 200
        # email is unchanged!
        self.assertEqual(
            rv.json,
            {
                "name": "user",
                "email": "*****@*****.**",
                "role": ROLE_MEMBER,
                "full_name": "My Name",
            },
        )
        # user cannot change others
        rv = self.client.put(
            BASE_URL + "/users/owner/",
            headers={"Authorization": "Bearer {}".format(token_user)},
            json={"full_name": "My Name"},
        )
        assert rv.status_code == 403
        # owner can edit others
        rv = self.client.put(
            BASE_URL + "/users/user/",
            headers={"Authorization": "Bearer {}".format(token_owner)},
            json={"full_name": "His Name"},
        )
        assert rv.status_code == 201
        rv = self.client.get(
            BASE_URL + "/users/user/",
            headers={"Authorization": "Bearer {}".format(token_owner)},
        )
        assert rv.status_code == 200
        self.assertEqual(
            rv.json,
            {
                "name": "user",
                "email": "*****@*****.**",
                "role": ROLE_MEMBER,
                "full_name": "His Name",
            },
        )

    def test_add_user(self):
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
        token_user = rv.json["access_token"]
        rv = self.client.post(
            BASE_URL + "/token/",
            json={
                "username": "******",
                "password": "******"
            },
        )
        assert rv.status_code == 200
        token_owner = rv.json["access_token"]
        # user cannot add user
        rv = self.client.post(
            BASE_URL + "/users/new_user/",
            headers={"Authorization": "Bearer {}".format(token_user)},
            json={
                "email": "*****@*****.**",
                "role": ROLE_MEMBER,
                "full_name": "My Name",
                "password": "******",
            },
        )
        assert rv.status_code == 403
        # missing password
        rv = self.client.post(
            BASE_URL + "/users/new_user/",
            headers={"Authorization": "Bearer {}".format(token_owner)},
            json={
                "email": "*****@*****.**",
                "role": ROLE_MEMBER,
                "full_name": "My Name",
            },
        )
        assert rv.status_code == 422
        # existing user
        rv = self.client.post(
            BASE_URL + "/users/user/",
            headers={"Authorization": "Bearer {}".format(token_owner)},
            json={
                "email": "*****@*****.**",
                "role": ROLE_MEMBER,
                "full_name": "New Name",
                "password": "******",
            },
        )
        assert rv.status_code == 409
        # OK
        rv = self.client.post(
            BASE_URL + "/users/new_user/",
            headers={"Authorization": "Bearer {}".format(token_owner)},
            json={
                "email": "*****@*****.**",
                "role": ROLE_MEMBER,
                "full_name": "New Name",
                "password": "******",
            },
        )
        assert rv.status_code == 201
        rv = self.client.get(
            BASE_URL + "/users/new_user/",
            headers={"Authorization": "Bearer {}".format(token_owner)},
        )
        assert rv.status_code == 200
        # email is unchanged!
        self.assertEqual(
            rv.json,
            {
                "email": "*****@*****.**",
                "role": ROLE_MEMBER,
                "full_name": "New Name",
                "name": "new_user",
            },
        )
        # check token for new user
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": "******",
                                  "password": "******"
                              })
        assert rv.status_code == 200
예제 #11
0
class TestConfig(unittest.TestCase):
    """Test cases for the /api/config/ endpoints."""
    def setUp(self):
        """Create a test database, app, users, and per-role auth headers.

        Builds a SQLite family tree titled ``self.name``, creates the app with
        ``ENV_CONFIG_FILE`` pointed at ``TEST_AUTH_CONFIG``, adds a ROLE_MEMBER
        user and a ROLE_OWNER user, pushes a test request context, and stores
        bearer-token headers in ``self.header_member`` and
        ``self.header_owner``.
        """
        self.name = "Test Web API"
        self.dbman = CLIDbManager(DbState())
        # The returned (path, title) pair is not needed by these tests.
        self.dbman.create_new_db_cli(self.name, dbid="sqlite")
        self.ctx = None
        try:
            with patch.dict("os.environ", {ENV_CONFIG_FILE: TEST_AUTH_CONFIG}):
                self.app = create_app()
            self.app.config["TESTING"] = True
            self.client = self.app.test_client()
            sqlauth = self.app.config["AUTH_PROVIDER"]
            sqlauth.create_table()
            sqlauth.add_user(name="user",
                             password="******",
                             email="*****@*****.**",
                             role=ROLE_MEMBER)
            sqlauth.add_user(name="admin",
                             password="******",
                             email="*****@*****.**",
                             role=ROLE_OWNER)
            ctx = self.app.test_request_context()
            ctx.push()
            # Only record the context once it is actually pushed, so the
            # error path below never pops a context that was not pushed.
            self.ctx = ctx
            self.header_member = self._token_header("******", "******")
            self.header_owner = self._token_header("******", "******")
        except Exception:
            # unittest does not call tearDown when setUp raises, so undo the
            # partial setup here instead of leaking the database/context.
            if self.ctx is not None:
                self.ctx.pop()
            self.dbman.remove_database(self.name)
            raise

    def _token_header(self, username, password):
        """Log in via ``/token/`` and return an ``Authorization`` header dict.

        Asserts the login returned 200 (showing the response body otherwise),
        rather than failing later with an opaque error on ``access_token``.
        """
        rv = self.client.post(BASE_URL + "/token/",
                              json={
                                  "username": username,
                                  "password": password
                              })
        assert rv.status_code == 200, rv.data
        return {"Authorization": f"Bearer {rv.json['access_token']}"}

    def tearDown(self):
        """Pop the test request context and delete the test database.

        The database removal runs in ``finally`` so that an error while
        popping the context does not leave the test family tree behind.
        """
        try:
            self.ctx.pop()
        finally:
            self.dbman.remove_database(self.name)

    def test_get_config(self):
        """Members get 403 on the config listing; owners get an empty dict."""
        config_url = f"{BASE_URL}/config/"

        member_resp = self.client.get(config_url, headers=self.header_member)
        assert member_resp.status_code == 403

        owner_resp = self.client.get(config_url, headers=self.header_owner)
        assert owner_resp.status_code == 200
        assert owner_resp.json == {}

    def test_set_config_unauth(self):
        """A member attempting to PUT a config value is rejected with 403."""
        payload = {"value": "myhost"}
        response = self.client.put(f"{BASE_URL}/config/EMAIL_HOST/",
                                   headers=self.header_member,
                                   json=payload)
        assert response.status_code == 403

    def test_set_config_put(self):
        """An owner can PUT a config value and read it back with GET."""
        endpoint = f"{BASE_URL}/config/EMAIL_HOST/"
        owner = self.header_owner

        put_resp = self.client.put(endpoint,
                                   headers=owner,
                                   json={"value": "host1"})
        assert put_resp.status_code == 200

        get_resp = self.client.get(endpoint, headers=owner)
        assert get_resp.status_code == 200
        assert get_resp.json == "host1"

    def test_config_delete(self):
        """A config value set by an owner is gone (404) after DELETE."""
        url = f"{BASE_URL}/config/EMAIL_HOST/"
        rv = self.client.put(
            url,
            headers=self.header_owner,
            json={"value": "host2"},
        )
        # Check the setup step so a failed PUT is not misreported later.
        assert rv.status_code == 200
        rv = self.client.get(url, headers=self.header_owner)
        assert rv.status_code == 200
        assert rv.json == "host2"
        rv = self.client.delete(url, headers=self.header_owner)
        # The exact DELETE success code is not pinned here; any 2xx passes.
        assert 200 <= rv.status_code < 300
        rv = self.client.get(url, headers=self.header_owner)
        assert rv.status_code == 404

    def test_config_reset_password(self):
        """Check that the config options are picked up in the reset email."""
        def get_from_host():
            """Trigger a reset email for "user"; return (From, link host).

            ``smtplib.SMTP_SSL`` is patched, so the message is captured from
            the mocked connection instead of being sent.
            """
            with patch("smtplib.SMTP_SSL") as mock_smtp:
                self.client.post(
                    f"{BASE_URL}/users/user/password/reset/trigger/")
                context = mock_smtp.return_value
                context.send_message.assert_called()
                # Take the message from the send_message call itself rather
                # than the first recorded method call, which would pick the
                # wrong call if e.g. login() ran on the connection first.
                msg = context.send_message.call_args[0][0]
                # Drop "=\n" (presumably quoted-printable soft line breaks)
                # so URLs wrapped across lines become contiguous again.
                body = msg.get_body().get_payload().replace("=\n", "")
                matches = re.findall(r".*(https?://[^/]+)/api", body)
                assert matches, body
                return msg["From"], matches[0]

        from_email, host = get_from_host()
        assert from_email == ""
        assert host == "http://localhost"
        rv = self.client.put(
            f"{BASE_URL}/config/BASE_URL/",
            headers=self.header_owner,
            json={"value": "https://www.example.com"},
        )
        assert rv.status_code == 200
        rv = self.client.put(
            f"{BASE_URL}/config/DEFAULT_FROM_EMAIL/",
            headers=self.header_owner,
            json={"value": "*****@*****.**"},
        )
        assert rv.status_code == 200
        from_email, host = get_from_host()
        assert from_email == "*****@*****.**"
        assert host == "https://www.example.com"