예제 #1
0
def get_display_mode(context: Context, db: Db, owner: Message.author, result: str) -> \
        Tuple[Optional[str], Optional[str], Optional[int]]:
    """
    Decide how a command's outcome is shown, based on the owner's display mode.

    - POST: the owner's full list is returned as text for a brand-new message.
    - EDIT: the full list is returned as replacement text for the last list message.
    - UPDATE: like EDIT, and ``result`` is also returned as text for a new message.
    When no last message ID is stored for this owner in this channel, EDIT and
    UPDATE fall back to the POST behaviour (there is nothing to edit).

    :param context: Discord message Context object; ``context.channel.id`` selects
        which stored last message is looked up
    :param db: Opened Db object
    :param owner: Author whose list and display mode are used (``.id`` and ``.name`` are read)
    :param result: Output generated by the particular function. e.g., "[ ] Task name  (1)"
    :return: text to post, text used to edit last message, last message ID
    :raises ValueError: if the stored display mode is not POST, EDIT or UPDATE
        (only reachable when a last message ID exists)
    """
    full_list = f"{owner.name}'s list\n" + print_task_list(db, owner)
    mode = db.get_display_mode(owner.id)

    # POST never needs the previous message, so skip that lookup entirely.
    previous_id = None
    if mode != "POST":
        previous_id = db.get_last_message_id(owner.id, context.channel.id)

    if previous_id is None:
        # Either POST mode, or nothing to edit yet: post the whole list fresh.
        return full_list, None, None

    outcomes = {
        "EDIT": (None, full_list, previous_id),
        "UPDATE": (result, full_list, previous_id),
    }
    if mode not in outcomes:
        raise ValueError(f"{mode} is an unrecognized display mode.")
    return outcomes[mode]
예제 #2
0
def add_tasks(context, owner: Message.author, message):
    """
    Create one task per line of ``message`` and append them to the owner's list.

    :param context: Discord message Context object
    :param owner: Author whose list is extended
    :param message: Newline-separated task names
    :return: text to post, text used to edit last message, last message ID
        (as produced by ``get_display_mode``)
    """
    with Db() as db:
        existing_ids = get_list_items(db, owner)
        added_ids = db.add_tasks(message.split("\n"), owner.id)
        combined_ids = existing_ids + added_ids
        db.update_list_items(combined_ids, owner.id)
        summary = f"{owner.name} added tasks to their list"
        return get_display_mode(context, db, owner, summary)
예제 #3
0
def remove_tasks(context, owner: Message.author, message):
    """
    Remove one or more tasks from the owner's list.

    Every line of ``message`` identifies one task, either by 1-based position or
    by name (see ``find_task_id_in_list``). All positions refer to the list as it
    was *before* this command, so "1\\n2" removes the first two items.

    :param context: Discord message Context object
    :param owner: Author whose list is modified
    :param message: Newline-separated positions and/or task names
    :return: text to post, text used to edit last message, last message ID
        (as produced by ``get_display_mode``)
    :raises ListBotError: propagated from ``find_task_id_in_list`` when an item
        cannot be resolved to exactly one task
    """
    with Db() as db:
        task_ids = get_list_items(db, owner)
        # Resolve every item against the unmodified list first. Removing while
        # resolving shifts later positions. It also makes a repeated item fail
        # in list.remove().
        ids_to_remove = {
            find_task_id_in_list(db, task_ids, item)
            for item in message.split("\n")
        }
        remaining_ids = [task_id for task_id in task_ids
                         if task_id not in ids_to_remove]
        db.update_list_items(remaining_ids, owner.id)
        return get_display_mode(context, db, owner,
                                f"{owner.name} removed tasks from their list")
예제 #4
0
def uncheck_task(context, owner: Message.author, message):
    """
    Mark a completed task as not completed again.

    :param context: Discord message Context object
    :param owner: Author whose list contains the task
    :param message: Position or name of the task (see ``find_task_id_in_list``)
    :return: text to post, text used to edit last message, last message ID.
        If the task is not in the CHECKED state, only an explanatory message is returned.
    """
    with Db() as db:
        target_id = find_task_id_in_list(db, get_list_items(db, owner), message)
        if db.get_task_state(target_id) == "CHECKED":
            db.uncomplete_task(target_id)
            return get_display_mode(context, db, owner,
                                    f"{owner.name} unchecked task '{message}'")
        return "That task hasn't been completed.", None, None
예제 #5
0
def task_time(context, owner: Message.author, message):
    """
    Report the time spent on each task in the owner's list, plus the total.

    :param context: Discord message Context object (unused)
    :param owner: Author whose list is reported
    :param message: Command argument text (unused)
    :return: (report text, None, None)
    """
    with Db() as db:
        tasks = db.get_tasks(get_list_items(db, owner))
    # Formatting needs no database access, so do it after the Db session closes.
    # Build the per-task lines with join() rather than repeated string +=.
    task_lines = "".join(
        f"{t['name']}: {pretty_task_time(t['time_spent_sec'])}\n" for t in tasks
    )
    total_sec = sum(t['time_spent_sec'] for t in tasks)
    output = (f"{owner.name} times\n"
              + task_lines
              + f"\nTotal time: {pretty_task_time(total_sec)}")
    return output, None, None
예제 #6
0
def check_all(context, owner: Message.author, message):
    """
    Mark every not-yet-checked task in the owner's list as completed.

    :param context: Discord message Context object
    :param owner: Author whose list is completed
    :param message: Command argument text (unused)
    :return: text to post, text used to edit last message, last message ID
        (as produced by ``get_display_mode``)
    """
    with Db() as db:
        for list_task_id in get_list_items(db, owner):
            # Already-checked tasks are left untouched.
            if db.get_task_state(list_task_id) == "CHECKED":
                continue
            db.complete_task(list_task_id)
        summary = f"{owner.name} finished all tasks in their list."
        return get_display_mode(context, db, owner, summary)
예제 #7
0
def stop_task(context, owner: Message.author, message):
    """
    Stop a task that is currently in the STARTED state.

    :param context: Discord message Context object
    :param owner: Author whose list contains the task
    :param message: Position or name of the task (see ``find_task_id_in_list``)
    :return: text to post, text used to edit last message, last message ID.
        If the task is not STARTED, only an explanatory message is returned.
    """
    with Db() as db:
        target_id = find_task_id_in_list(db, get_list_items(db, owner), message)
        if db.get_task_state(target_id) == "STARTED":
            db.stop_task(target_id)
            return get_display_mode(context, db, owner,
                                    f"{owner.name} stopped task '{message}'")
        return "That task hasn't been started.", None, None
예제 #8
0
def new_list(context, owner: Message.author, message):
    """
    Replace the owner's list with a new one built from ``message``.

    Registers the owner first (``db.add_owner``), creates one task per line of
    ``message``, and stores them as the owner's new list.

    :param context: Discord message Context object (unused)
    :param owner: Author the list belongs to
    :param message: Newline-separated task names
    :return: (text to post, None, None)
    """
    if not message or not message.strip():
        # Return the same 3-tuple shape as every other handler, so callers can
        # treat all results uniformly. A whitespace-only message is treated as
        # empty; otherwise it would create a blank task.
        return ("I need items to make a list. Put each separate item on a new line.",
                None, None)
    task_names = message.split("\n")
    with Db() as db:
        db.add_owner(owner.id, owner.name)
        task_ids = db.add_tasks(task_names, owner.id)
        db.new_list(task_ids, owner.id)
        return f"Created a new list for {owner.name}\n" + print_task_list(
            db, owner), None, None
예제 #9
0
def start_task(context, owner: Message.author, message):
    """
    Start a task, enforcing that at most one task in the list is STARTED.

    :param context: Discord message Context object
    :param owner: Author whose list contains the task
    :param message: Position or name of the task (see ``find_task_id_in_list``)
    :return: text to post, text used to edit last message, last message ID.
        If the task is already started, or another task is, only an
        explanatory message is returned.
    """
    with Db() as db:
        list_ids = get_list_items(db, owner)
        chosen_id = find_task_id_in_list(db, list_ids, message)
        already_started = db.get_task_state(chosen_id) == "STARTED"
        if already_started:
            return "That task is already started.", None, None
        other_started = "STARTED" in db.get_task_states(list_ids)
        if other_started:
            return "You can only have one started task at a time.", None, None
        db.start_task(chosen_id)
        summary = f"{owner.name} started task '{message}'"
        return get_display_mode(context, db, owner, summary)
예제 #10
0
def find_task_id_in_list(db: Db, task_ids: List[int], item: str) -> int:
    """
    Resolve a user-supplied reference to a single task ID from ``task_ids``.

    ``item`` is either a 1-based list position (e.g. "3") or text that is
    matched against task names via ``db.filter_task_ids_by_name``.

    :param db: Opened Db object, used for name lookups
    :param task_ids: Task IDs of the owner's list, in list order
    :param item: 1-based position, or text identifying exactly one task by name
    :return: The matching task ID
    :raises ListBotError: if the position is out of range, or if the text
        matches no task or more than one task
    """
    # isdecimal() rather than isdigit(): characters such as "²" satisfy
    # isdigit() but make int() raise ValueError.
    if item.isdecimal():
        position = int(item)
        # Positions are 1-based. Reject 0 explicitly, because task_ids[-1]
        # would otherwise silently return the *last* task.
        if not 1 <= position <= len(task_ids):
            raise ListBotError(f"{item} is not a valid list position.")
        return task_ids[position - 1]

    filtered_task_ids = db.filter_task_ids_by_name(task_ids, item)
    if not filtered_task_ids:
        raise ListBotError(f"Couldn't find any list item matching \"{item}\"")
    if len(filtered_task_ids) > 1:
        task_names = db.get_task_names(filtered_task_ids)
        raise ListBotError(f"Multiple items found matching \"{item}\":\n" + "\n".join(task_names))
    return filtered_task_ids[0]
예제 #11
0
def clear_checked_tasks(context, owner: Message.author, message):
    """
    Drop every CHECKED task from the owner's list, keeping the rest in order.

    :param context: Discord message Context object (unused)
    :param owner: Author whose list is pruned
    :param message: Command argument text (unused)
    :return: (text of the updated list, None, None)
    """
    with Db() as db:
        list_ids = get_list_items(db, owner)
        states = db.get_task_states(list_ids)
        kept_ids = [list_id for list_id, state in zip(list_ids, states)
                    if state != "CHECKED"]
        db.update_list_items(kept_ids, owner.id)
        header = f"{owner.name}'s list\n"
        return header + print_task_list(db, owner), None, None
예제 #12
0
def _reorder_task(context, owner: Message.author, item: str, position: int):
    """
    Move one task to a new position in the owner's list.

    :param context: Discord message Context object
    :param owner: Author whose list is reordered
    :param item: Position or name of the task to move (see ``find_task_id_in_list``)
    :param position: 1-based target position, or -1 to move the task to the end.
        A position past the end of the list also moves the task to the end.
    :return: text to post, text used to edit last message, last message ID
        (as produced by ``get_display_mode``)
    :raises ListBotError: if ``position`` is 0 or a negative value other than -1,
        or if ``item`` does not identify exactly one task
    """
    # Without this check, position 0 (or < -1) becomes a negative index for
    # list.insert. That silently places the task near the end of the list.
    if position != -1 and position < 1:
        raise ListBotError(f"{position} is not a valid list position.")
    with Db() as db:
        task_ids = get_list_items(db, owner)
        task_id = find_task_id_in_list(db, task_ids, item)
        task_ids.remove(task_id)
        if position == -1:
            task_ids.append(task_id)
        else:
            task_ids.insert(position - 1, task_id)
        db.update_list_items(task_ids, owner.id)
        return get_display_mode(
            context, db, owner,
            f"{owner.name} moved task '{item}' to {position}")
예제 #13
0
def check_list(context, owner: Message.author, message):
    """
    Mark one or more tasks, given by 1-based position, as completed.

    Either all named tasks are completed or none are: every position is
    validated before any task is changed.

    :param context: Discord message Context object
    :param owner: Author whose list contains the tasks
    :param message: Whitespace-separated 1-based positions, e.g. "1 3 4"
    :return: text to post, text used to edit last message, last message ID.
        If an argument is not a position, or a task is already CHECKED, only
        an explanatory message is returned and nothing is changed.
    """
    with Db() as db:
        list_items = get_list_items(db, owner)
        # split() with no argument ignores repeated, leading and trailing
        # whitespace. split(" ") produced empty tokens, which were then
        # rejected as "names".
        tasks = message.split()
        if not tasks or not all(s.isdecimal() for s in tasks):
            return "All arguments must be positions of tasks, not names.", None, None
        task_ids = []
        for task in tasks:
            task_id = find_task_id_in_list(db, list_items, task)
            if db.get_task_state(task_id) == "CHECKED":
                return f"Task {task} has already been finished.", None, None
            # A repeated position must not complete the same task twice.
            if task_id not in task_ids:
                task_ids.append(task_id)
        for task_id in task_ids:
            db.complete_task(task_id)
        return get_display_mode(context, db, owner,
                                f"{owner.name} finished tasks '{message}'")
예제 #14
0
def check_version():
    """Run ``Db.check_version`` inside a managed ``Db`` session."""
    with Db() as database:
        database.check_version()
예제 #15
0
class TestDatabase(unittest.TestCase):
    """
    Tests for ``Db`` run against a real connection.

    Each test's changes are rolled back in ``tearDown``.
    """

    @classmethod
    def setUpClass(cls):
        # Change pwd to root folder, same as when main.py is run.
        # Remember the starting directory so tearDownClass can undo the chdir
        # instead of leaking it into every test that runs afterwards.
        cls._original_cwd = os.getcwd()
        os.chdir("..")

    @classmethod
    def tearDownClass(cls):
        os.chdir(cls._original_cwd)

    def setUp(self):
        self.db = Db()

    def tearDown(self):
        # Discard anything the test wrote, then release the connection.
        self.db.conn.rollback()
        self.db.conn.close()

    def test_add_user(self):
        self.db.add_owner(OWNER_ID1, OWNER_NAME1)
        self.assertEqual(OWNER_ID1, self.db.get_owner_id(OWNER_NAME1))

    def test_add_user_twice(self):
        self.db.add_owner(OWNER_ID1, OWNER_NAME1)
        self.db.add_owner(OWNER_ID1, OWNER_NAME2)
        self.assertEqual(OWNER_ID1, self.db.get_owner_id(OWNER_NAME1))

        sql = "SELECT COUNT(*) FROM owners WHERE name=?"
        self.assertEqual(1, self.db.cur.execute(sql, [OWNER_NAME1]).fetchone()[0])
        self.assertEqual(0, self.db.cur.execute(sql, [OWNER_NAME2]).fetchone()[0])

    def test_get_display_mode(self):
        self.db.add_owner(OWNER_ID1, OWNER_NAME1)
        self.assertEqual("EDIT", self.db.get_display_mode(OWNER_ID1))
        self.db.set_display_mode(OWNER_ID1, "POST")
        self.assertEqual("POST", self.db.get_display_mode(OWNER_ID1))
        with self.assertRaisesRegex(ValueError, r"^BAD_VALUE is not a valid display mode. Valid values are "
                                                r"\['POST', 'EDIT', 'UPDATE'\]$"):
            self.db.set_display_mode(OWNER_ID1, "BAD_VALUE")
        # Set display_mode to a bad value directly, bypassing set_display_mode's validation
        sql = "UPDATE owners SET display_mode=? WHERE id=?"
        self.db.cur.execute(sql, ["BAD_VALUE", OWNER_ID1])
        with self.assertRaisesRegex(ValueError, r"^BAD_VALUE is not a valid display mode. Valid values are "
                                                r"\['POST', 'EDIT', 'UPDATE'\]$"):
            self.db.get_display_mode(OWNER_ID1)

    def test_add_task(self):
        rowid = self.db.add_task("Spoon Spain", OWNER_ID1)
        retrieved_rowid = self.db.filter_task_ids_by_name([rowid], "spain")
        self.assertEqual(rowid, retrieved_rowid[0])

    def test_add_multiple_tasks(self):
        rowids = self.db.add_tasks(["Spoon Spain", "Spank Spain", "Split Spain", "Flog France"], OWNER_ID1)
        retrieved_rowids = self.db.filter_task_ids_by_name([rowids[0], rowids[2], rowids[3]], "spain")
        self.assertEqual(rowids[0], retrieved_rowids[0])
        self.assertEqual(rowids[2], retrieved_rowids[1])
        self.assertEqual(2, len(retrieved_rowids))

    def test_get_task_name(self):
        rowid1 = self.db.add_task("Spoon Spain", OWNER_ID1)
        rowid2 = self.db.add_task("Spank Spain", OWNER_ID1)
        self.assertEqual("Spoon Spain", self.db.get_task_name(rowid1))
        self.assertEqual("Spank Spain", self.db.get_task_name(rowid2))

    def test_get_task_names(self):
        rowid1 = self.db.add_task("Spoon Spain", OWNER_ID1)
        # Inserted but deliberately not requested, to check that it is excluded.
        rowid2 = self.db.add_task("Spank Spain", OWNER_ID1)
        rowid3 = self.db.add_task("Splain Spain", OWNER_ID1)
        self.assertEqual(['Spoon Spain', 'Splain Spain'], self.db.get_task_names([rowid1, rowid3]))

    def test_rename_task(self):
        rowid = self.db.add_task("Spoon Spain", OWNER_ID1)
        self.db.rename_task(rowid, "Fork France")
        self.assertEqual("Fork France", self.db.get_task_name(rowid))

    def test_start_task(self):
        rowid = self.db.add_task("Spoon Spain", OWNER_ID1)
        self.assertEqual("NOT_STARTED", self.db.get_task_state(rowid))
        self.db.start_task(rowid)
        self.assertEqual("STARTED", self.db.get_task_state(rowid))
        started_time = ts_to_epoch(self.db.get_task_start_time(rowid))
        # Accept any epoch that's within one second of the current epoch.
        self.assertAlmostEqual(now_epoch(), started_time, delta=1)

    def test_stop_task(self):
        rowid = self.db.add_task("Spoon Spain", OWNER_ID1)
        five_seconds_ago = epoch_to_ts(now_epoch() - 5)
        sql = """
        UPDATE tasks
        SET 
            state = 'STARTED',
            started_ts = ?
        WHERE 
          rowid = ?;
        """
        self.db.cur.execute(sql, [five_seconds_ago, rowid])
        self.assertEqual("STARTED", self.db.get_task_state(rowid))
        self.db.stop_task(rowid)
        self.assertEqual("NOT_STARTED", self.db.get_task_state(rowid))
        self.assertIsNone(self.db.get_task_start_time(rowid))
        self.assertAlmostEqual(5, self.db.get_time_spent_sec(rowid), delta=1)

    def test_complete_task(self):
        rowid = self.db.add_task("Spoon Spain", OWNER_ID1)
        five_seconds_ago = epoch_to_ts(now_epoch() - 5)
        sql = """
        UPDATE tasks
        SET 
            state = 'STARTED',
            started_ts = ?
        WHERE 
          rowid = ?;
        """
        self.db.cur.execute(sql, [five_seconds_ago, rowid])
        self.assertEqual("STARTED", self.db.get_task_state(rowid))
        self.db.complete_task(rowid)
        self.assertEqual("CHECKED", self.db.get_task_state(rowid))
        completed_time = ts_to_epoch(self.db.get_task_complete_time(rowid))
        # Accept any epoch that's within one second of the current epoch.
        self.assertAlmostEqual(now_epoch(), completed_time, delta=1)
        self.assertAlmostEqual(5, self.db.get_time_spent_sec(rowid), delta=1)

    def test_complete_task_not_started(self):
        rowid = self.db.add_task("Spoon Spain", OWNER_ID1)
        five_seconds_ago = epoch_to_ts(now_epoch() - 5)
        sql = "UPDATE tasks SET started_ts = ? WHERE rowid = ?;"
        self.db.cur.execute(sql, [five_seconds_ago, rowid])
        self.db.complete_task(rowid)
        # Wall-clock based: allow the same one-second slack as the sibling
        # tests, so a second-boundary crossing does not make this flaky.
        self.assertAlmostEqual(5, self.db.get_time_spent_sec(rowid), delta=1)

    def test_uncomplete_task(self):
        rowid = self.db.add_task("Spoon Spain", OWNER_ID1)
        self.db.complete_task(rowid)
        self.assertIsNotNone(self.db.get_task_complete_time(rowid))
        self.db.uncomplete_task(rowid)
        self.assertIsNone(self.db.get_task_complete_time(rowid))

    def test_get_tasks(self):
        rowid1 = self.db.add_task("Spoon Spain", OWNER_ID1)
        rowid2 = self.db.add_task("Spank Spain", OWNER_ID1)
        rowid3 = self.db.add_task("Splain Spain", OWNER_ID1)
        self.db.start_task(rowid1)
        self.db.complete_task(rowid1)
        self.db.start_task(rowid2)
        current_time = now_ts()
        expected_result = [
            {'name': 'Spoon Spain',
             'owner_id': OWNER_ID1,
             'state': 'CHECKED',
             'created_ts': current_time,
             'started_ts': current_time,
             'completed_ts': current_time,
             'time_spent_sec': 0},
            {'name': 'Spank Spain',
             'owner_id': OWNER_ID1,
             'state': 'STARTED',
             'created_ts': current_time,
             'started_ts': current_time,
             'completed_ts': None,
             'time_spent_sec': 0},
            {'name': 'Splain Spain',
             'owner_id': OWNER_ID1,
             'state': 'NOT_STARTED',
             'created_ts': current_time,
             'started_ts': None,
             'completed_ts': None,
             'time_spent_sec': 0}
        ]
        self.assertEqual(expected_result, self.db.get_tasks([rowid1, rowid2, rowid3]))
        self.assertEqual(["CHECKED", "STARTED", "NOT_STARTED"], self.db.get_task_states([rowid1, rowid2, rowid3]))

    def test_get_list_items_no_list(self):
        self.assertIsNone(self.db.get_list_items(OWNER_ID1))

    def test_add_new_list(self):
        self.db.new_list([1, 2, 3], OWNER_ID1)
        current_time = now_ts()
        retrieved_list = self.db.get_list(OWNER_ID1)
        self.assertEqual([[1, 2, 3], OWNER_ID1, current_time, current_time], retrieved_list)
        retrieved_list = self.db.get_list_items(OWNER_ID1)
        self.assertEqual([1, 2, 3], retrieved_list)

    def test_add_and_replace_new_list(self):
        self.db.new_list([1, 2, 3], OWNER_ID1)
        self.db.new_list([4, 5, 6], OWNER_ID1)
        retrieved_list = self.db.get_list_items(OWNER_ID1)
        self.assertEqual([4, 5, 6], retrieved_list)

    def test_update_list(self):
        self.db.new_list([1, 2, 3], OWNER_ID1)
        self.db.update_list_items([7, 8, 9], OWNER_ID1)
        retrieved_list = self.db.get_list_items(OWNER_ID1)
        self.assertEqual([7, 8, 9], retrieved_list)

    def test_get_last_message_id(self):
        self.db.new_list([1, 2, 3], OWNER_ID1)
        channel_name1 = "channel_name1"
        last_message_id1 = 1111111
        channel_name2 = "channel_name2"
        last_message_id2 = 2222222
        self.assertIsNone(self.db.get_last_message_id(OWNER_ID1, channel_name1))
        self.db.set_last_message_id(OWNER_ID1, channel_name1, last_message_id1)
        self.assertEqual(last_message_id1, self.db.get_last_message_id(OWNER_ID1, channel_name1))
        self.assertIsNone(self.db.get_last_message_id(OWNER_ID1, channel_name2))
        self.db.set_last_message_id(OWNER_ID1, channel_name2, last_message_id2)
        # Setting channel 2 must not disturb channel 1's stored ID.
        self.assertEqual(last_message_id1, self.db.get_last_message_id(OWNER_ID1, channel_name1))
        self.assertEqual(last_message_id2, self.db.get_last_message_id(OWNER_ID1, channel_name2))
 def tearDown(self):
     """Delete all data stored for ``OWNER`` via ``Db.wipe_owner_data`` after each test."""
     with Db() as session:
         owner_id = OWNER.id
         session.wipe_owner_data(owner_id)
예제 #17
0
def set_last_message_id(context: Context, owner: Message.author,
                        response_message: Message):
    """
    Remember ``response_message`` as the owner's latest message in this channel.

    :param context: Discord message Context object; ``context.channel.id`` is the key
    :param owner: Author the message is recorded for (``owner.id`` is used)
    :param response_message: The message the bot just sent (``.id`` is stored)
    """
    with Db() as session:
        channel_id = context.channel.id
        message_id = response_message.id
        session.set_last_message_id(owner.id, channel_id, message_id)
예제 #18
0
 def setUp(self):
     """Give each test its own fresh ``Db`` instance on ``self.db``."""
     fresh_db = Db()
     self.db = fresh_db
예제 #19
0
def show_list(context, owner: Message.author, message):
    """
    Render the owner's current list.

    :param context: Discord message Context object (unused)
    :param owner: Author whose list is shown
    :param message: Command argument text (unused)
    :return: (text of the list, None, None)
    """
    with Db() as session:
        header = f"{owner.name}'s list\n"
        body = print_task_list(session, owner)
        return header + body, None, None