def test_signals(self):
    """Signal running rollouts and check how each task type reacts.

    One rollout is created for every (task class, signal tuple) pairing.
    All rollouts are started asynchronously, each is sent its signals,
    and after a pause the rollouts and their tracked tasks are reloaded
    by id and checked:

    * every rollout has a finish timestamp;
    * the tracked task has a run-start timestamp, except under
      SequentialExecTask where the tracked (third) child must not;
    * the root task has a revert-start timestamp unless
      'skip_rollback' was among the signals sent.

    The fixed sleeps are the only synchronisation with the rollouts.
    """
    task_classes = DelayTask, SequentialExecTask, ParallelExecTask
    signal_sets = ('abort_rollout',), ('term_rollout',), ('skip_rollback', 'term_rollout')
    # Every class/signal-set pairing, in nested-loop order (class outermost).
    cases = [(cls, sigs) for cls in task_classes for sigs in signal_sets]

    rollouts = {}
    rollout_ids = {}
    tracked_tasks = {}

    # Phase 1: build and persist one rollout per case.
    for case in cases:
        cls, sigs = case
        rollout = Rollout({})
        rollout.save()
        rollouts[case] = rollout
        rollout_ids[case] = rollout.id

        if cls is DelayTask:
            # A lone delay task is itself the task inspected later.
            tracked_tasks[case] = create_task(rollout, DelayTask, seconds=15)
        else:
            # Exec tasks wrap three delay children; the last child is
            # the one inspected later.
            children = [create_task(rollout, DelayTask, seconds=15) for _ in range(3)]
            create_task(rollout, cls, children)
            tracked_tasks[case] = children[-1]

    # Phase 2: start every rollout.
    for case in cases:
        rollouts[case].rollout_async()

    # Enough time for rollouts to start
    time.sleep(0.5)

    # Phase 3: deliver each case's signals, in order.
    for case in cases:
        target_id = rollout_ids[case]
        for sig in case[1]:
            self.assertTrue(Rollout._can_signal(target_id, sig))
            self.assertTrue(Rollout._do_signal(target_id, sig))
            self.assertTrue(Rollout._is_signalling(target_id, sig))

    # Enough time for rollouts to finish and save to db
    time.sleep(2)

    # Phase 4: reload from storage and verify the outcome.
    for case in cases:
        cls, sigs = case
        rollout = Rollout._from_id(rollout_ids[case])
        self.assertTrue(rollout.rollout_finish_dt, 'Rollout for %s not finished when sent %s' % (cls, sigs))

        task = Task._from_id(tracked_tasks[case].id)
        if cls is not SequentialExecTask:
            self.assertTrue(task.run_start_dt, 'Final task %s not run for %s rollout when sent %s' % (task, cls, sigs))
        else:
            # Sequential exec's last task should not have run due to aborts
            self.assertFalse(task.run_start_dt, 'Final task %s run for %s rollout when sent %s' % (task, cls, sigs))

        if 'skip_rollback' not in sigs:
            self.assertTrue(rollout.root_task.revert_start_dt, 'Rollout for %s not rolled back when sent %s' % (cls, sigs))
        else:
            # If rollbacks were skipped the root task should not have reverted
            self.assertFalse(rollout.root_task.revert_start_dt, 'Rollout for %s rolled back when sent %s' % (cls, sigs))
def thread_wrapped_task():
    """Reload a task by id and invoke one of its methods, flagging failures.

    Relies on names from the enclosing scope: ``task_id``,
    ``method_name``, ``outer_handlers`` and ``abort``.

    The work runs inside ``inner_thread_nested_setup(outer_handlers)``.
    Any ``Exception`` raised while importing, reloading or running the
    method is caught here: the traceback is printed, the failure is
    logged through ``logbook.exception()``, and ``abort.set()`` is called.
    Nothing is re-raised and nothing is returned.
    """
    with inner_thread_nested_setup(outer_handlers):
        try:
            # Reload from db
            from kettle.tasks import Task
            task = Task._from_id(task_id)
            getattr(task, method_name)()
        except Exception:
            # TODO: Fix logging
            # Parenthesised single-argument print emits the same text on
            # Python 2 and is also valid syntax on Python 3.
            print(traceback.format_exc())
            logbook.exception()
            # NOTE(review): abort is presumably a threading.Event shared
            # with the caller -- confirm against the enclosing function.
            abort.set()