Example #1
0
    def test_init_worker(self):
        """
        Worker ID should be stored correctly. With a progress bar, the last-updated timestamp and the number of
        completed tasks should be initialized too; without one, both should be None
        """
        worker_ids = [0, 1, 4]

        # One mocked "now" per progress bar initialization. The minute component equals the worker ID, which lets each
        # subtest derive the timestamp it expects from its own worker ID
        MockDatetimeNow.RETURN_VALUES = [datetime(1970, 1, 1, 0, minute, 0, 0) for minute in worker_ids]
        MockDatetimeNow.CURRENT_IDX = 0

        comms = WorkerComms(mp.get_context('fork'), 5)
        self.assertIsNone(comms.worker_id)

        for worker_id in worker_ids:
            # No progress bar: progress bar bookkeeping should stay unset (also after a previous progress bar run)
            with self.subTest(worker_id=worker_id, has_progress_bar=False):
                comms.init_comms(False, False)
                comms.init_worker(worker_id)
                self.assertEqual(comms.worker_id, worker_id)
                self.assertIsNone(comms._progress_bar_last_updated)
                self.assertIsNone(comms._progress_bar_n_tasks_completed)

            # Progress bar: the timestamp comes from the patched datetime and the counter starts at zero
            with self.subTest(worker_id=worker_id, has_progress_bar=True), \
                    patch('mpire.comms.datetime', new=MockDatetimeNow):
                comms.init_comms(False, True)
                comms.init_worker(worker_id)
                expected_last_updated = datetime(1970, 1, 1, 0, worker_id, 0, 0)
                self.assertEqual(comms.worker_id, worker_id)
                self.assertEqual(comms._progress_bar_last_updated, expected_last_updated)
                self.assertEqual(comms._progress_bar_n_tasks_completed, 0)
Example #2
0
    def test_exit_results_all_workers(self):
        """
        Test exit results related functions: each worker adds a single exit result (possibly None) and
        ``get_exit_results_all_workers`` should return them in worker ID order
        """
        comms = WorkerComms(mp.get_context('fork'), 4)
        comms.init_comms(True, False)

        # Every worker will always have a return value (even if it's the implicit None), hence the trailing None.
        # Note that `get_exit_results` calls `task_done`
        expected_results = [0, 1, 2, None]
        for worker_id, exit_result in enumerate(expected_results):
            comms.init_worker(worker_id)
            comms.add_exit_results(exit_result)
        self.assertListEqual(comms.get_exit_results_all_workers(), expected_results)

        # Should be joinable
        comms.join_results_queues()
Example #3
0
    def test_worker_restart(self):
        """
        Test worker restart related functions: signalling restarts, listing pending restarts and resetting them
        """
        comms = WorkerComms(mp.get_context('fork'), 5)
        comms.init_comms(False, False)

        def pending_restarts():
            # Materialize whatever `get_worker_restarts` yields so it can be compared to a list
            return list(comms.get_worker_restarts())

        # Nothing has been signalled yet
        self.assertListEqual(pending_restarts(), [])

        # Workers 0, 2 and 3 ask for a restart
        for worker_id in (0, 2, 3):
            comms.init_worker(worker_id)
            comms.signal_worker_restart()
        self.assertListEqual(pending_restarts(), [0, 2, 3])

        # Resetting workers 0 and 3 should leave only worker 2 pending
        for worker_id in (0, 3):
            comms.reset_worker_restart(worker_id)
        self.assertListEqual(pending_restarts(), [2])

        # Resetting the final one empties the list
        comms.reset_worker_restart(2)
        self.assertListEqual(pending_restarts(), [])
Example #4
0
    def test_exit_results_all_workers_exception_thrown(self):
        """
        Test exit results related functions. When an exception occurred, ``get_exit_results_all_workers`` should
        return an empty list
        """
        n_workers = 3
        comms = WorkerComms(mp.get_context('fork'), n_workers)
        comms.init_comms(True, False)

        # Each worker adds its own ID as exit result
        for worker_id in range(n_workers):
            comms.init_worker(worker_id)
            comms.add_exit_results(worker_id)

        # With the exception set, an empty list should come back
        comms.set_exception()
        self.assertListEqual(comms.get_exit_results_all_workers(), [])

        # Clear the exception flag, then fetch the results once more to drain them before joining
        comms._exception_thrown.clear()
        comms.get_exit_results_all_workers()
        comms.join_results_queues()
Example #5
0
    def test_exit_results(self):
        """
        Test exit results related functions: results are stored per worker and read back in the order they were added
        """
        comms = WorkerComms(mp.get_context('fork'), 3)
        comms.init_comms(True, False)
        worker_ids = range(3)

        # No worker has anything available yet
        for worker_id in worker_ids:
            with self.assertRaises(queue.Empty):
                comms.get_exit_results(worker_id, timeout=0)

        # Every worker adds three results: its own ID, a string and a dict
        for worker_id in worker_ids:
            comms.init_worker(worker_id)
            for exit_result in (worker_id, 'hello world', {'foo': 'bar'}):
                comms.add_exit_results(exit_result)

        # Read them back worker by worker. Note that `get_exit_results` calls `task_done`
        for worker_id in worker_ids:
            received = [comms.get_exit_results(worker_id=worker_id) for _ in range(3)]
            self.assertListEqual(received, [worker_id, 'hello world', {'foo': 'bar'}])

        # Should be joinable
        comms.join_results_queues()
Example #6
0
    def test_worker_alive(self):
        """
        Test worker alive related functions: marking workers alive/dead, querying their status and waiting for a dead
        worker
        """
        n_workers = 5
        comms = WorkerComms(mp.get_context('fork'), n_workers)
        comms.init_comms(False, False)

        def alive_status():
            # Alive flag of every worker, ordered by worker ID
            return [comms.is_worker_alive(worker_id) for worker_id in range(n_workers)]

        # Workers 0, 2 and 3 become alive; worker 1 becomes alive and then dead again; worker 4 is never touched
        comms.init_worker(0)
        comms.set_worker_alive()
        comms.init_worker(1)
        comms.set_worker_alive()
        comms.set_worker_dead()
        for worker_id in (2, 3):
            comms.init_worker(worker_id)
            comms.set_worker_alive()
        self.assertListEqual(alive_status(), [True, False, True, True, False])

        # Mark workers 0 and 3 as dead again
        for worker_id in (0, 3):
            comms.init_worker(worker_id)
            comms.set_worker_dead()
        self.assertListEqual(alive_status(), [False, False, True, False, False])

        # wait_for_dead_worker is tested by checking it calls `wait` on `_workers_dead[worker_id]` exactly once
        for worker_id in range(n_workers):
            with patch.object(comms._workers_dead[worker_id], 'wait') as mocked_wait:
                comms.wait_for_dead_worker(worker_id)
                self.assertEqual(mocked_wait.call_count, 1)
Example #7
0
    def test_progress_bar(self):
        """
        Test progress bar related functions: enabling the progress bar, throttled and forced task completed updates,
        the poison pill, and what is returned once an exception has been caught
        """
        comms = WorkerComms(mp.get_context('fork'), 2)

        # Has progress bar
        self.assertFalse(comms.has_progress_bar())
        comms.init_comms(False, True)
        self.assertTrue(comms.has_progress_bar())

        # Initialize worker. All mocked timestamps are the same moment, so as far as the comms can tell no time
        # passes between the calls below
        MockDatetimeNow.RETURN_VALUES = [datetime(1970, 1, 1, 0, 0, 0, 0) for _ in range(5)]
        MockDatetimeNow.CURRENT_IDX = 0
        with patch('mpire.comms.datetime', new=MockDatetimeNow):
            comms.init_worker(0)

        # Nothing available yet
        with self.assertRaises(queue.Empty):
            comms._task_completed_queue.get(block=False)

        # 3 tasks done, but not enough time has passed, so nothing should have been put on the queue yet. A plain loop
        # is used because only the side effect of each call matters
        with patch('mpire.comms.datetime', new=MockDatetimeNow):
            for _ in range(3):
                comms.task_completed_progress_bar()
        with self.assertRaises(queue.Empty):
            comms._task_completed_queue.get(block=False)

        # 1 more task done. Not enough time has passed, but we'll force the update. Number of tasks done should be
        # aggregated to 4
        with patch('mpire.comms.datetime', new=MockDatetimeNow):
            comms.task_completed_progress_bar(force_update=True)
        self.assertEqual(comms.get_tasks_completed_progress_bar(), (4, True))
        comms.task_done_progress_bar()

        # 3 tasks done. Enough time should've passed for each update call, so each one arrives as a separate update
        MockDatetimeNow.RETURN_VALUES = [
            datetime(1970, 1, 1, 0, 1, 0, 0),
            datetime(1970, 1, 1, 0, 2, 0, 0),
            datetime(1970, 1, 1, 0, 3, 0, 0)
        ]
        MockDatetimeNow.CURRENT_IDX = 0
        with patch('mpire.comms.datetime', new=MockDatetimeNow):
            for _ in range(3):
                comms.task_completed_progress_bar()
        self.assertListEqual(
            [comms.get_tasks_completed_progress_bar() for _ in range(3)],
            [(1, True), (1, True), (1, True)])
        for _ in range(3):
            comms.task_done_progress_bar()

        # Add poison pill
        comms.add_progress_bar_poison_pill()
        self.assertEqual(comms.get_tasks_completed_progress_bar(), (POISON_PILL, True))
        comms.task_done_progress_bar()

        # Once an exception has been caught, `get_tasks_completed_progress_bar` should return (POISON_PILL, False)
        comms.set_exception_caught()
        self.assertEqual(comms.get_tasks_completed_progress_bar(), (POISON_PILL, False))

        # Should be joinable now
        comms.join_progress_bar_task_completed_queue()