예제 #1
0
class TestRedisSseStream(TestCase, SseStreamTestBase):
    """Run the shared SseStream tests against a MockRedisCacheAdapter."""

    @classmethod
    def setUpClass(cls):
        # Called once before any test in the class runs
        # (presumably gevent-style monkey patching -- TODO confirm).
        patch_all()

    def setUp(self):
        mock_cache = MockRedisCacheAdapter()
        channel_name = 'channel1'
        self.cache, self.channel = mock_cache, channel_name
        self.stream = SseStream(channel_name, mock_cache)

    def tearDown(self):
        """Drop whatever the test stored in the mock cache."""
        self.cache.clear()
예제 #2
0
class TestRedisSseStream(TestCase, SseStreamTestBase):
    """Exercise SseStreamTestBase with an SseStream backed by a mock Redis cache."""

    @classmethod
    def setUpClass(cls):
        # One-time, class-wide patching before the tests run
        # (assumed to be gevent monkey patching -- TODO confirm).
        patch_all()

    def setUp(self):
        """Create a fresh mock cache and an SseStream on 'channel1'."""
        self.cache = MockRedisCacheAdapter()
        self.channel = 'channel1'
        stream = SseStream(self.channel, self.cache)
        self.stream = stream

    def tearDown(self):
        # Reset the mock cache so state does not carry over to the next test.
        self.cache.clear()
예제 #3
0
class TestWorkflowReceiver(TestCase):
    """Tests for WorkflowReceiver reading encrypted workflow requests from a cache.

    Each request is serialized, encrypted with a NaCl ``Box`` built from the
    client private key and the server public key loaded in ``setUpClass``, and
    pushed onto the ``request_queue`` list of a MockRedisCacheAdapter before
    the receiver's generator is advanced.
    """

    @classmethod
    def setUpClass(cls):
        initialize_test_config()
        keys_path = walkoff.config.Config.ZMQ_PRIVATE_KEYS_PATH
        secret_key_bytes = nacl.bindings.crypto_box_SECRETKEYBYTES
        # Only the secret halves of the certificates are needed; the public
        # halves returned by load_certificate are discarded.
        _, server_secret = auth.load_certificate(
            os.path.join(keys_path, "server.key_secret"))
        _, client_secret = auth.load_certificate(
            os.path.join(keys_path, "client.key_secret"))
        cls.key = PrivateKey(client_secret[:secret_key_bytes])
        cls.server_key = PrivateKey(server_secret[:secret_key_bytes]).public_key
        cls.box = Box(cls.key, cls.server_key)

    @patch.object(walkoff.cache,
                  'make_cache',
                  return_value=MockRedisCacheAdapter())
    def test_init(self, mock_make_cache):
        receiver = WorkflowReceiver(self.key, self.server_key,
                                    walkoff.config.Config.CACHE)
        self.assertEqual(receiver.key, self.key)
        self.assertEqual(receiver.server_key, self.server_key)
        mock_make_cache.assert_called_once_with(walkoff.config.Config.CACHE)
        self.assertIsInstance(receiver.cache, MockRedisCacheAdapter)
        self.assertFalse(receiver.exit)

    def get_receiver(self):
        """Return a WorkflowReceiver backed by a fresh MockRedisCacheAdapter.

        ``walkoff.cache.make_cache`` is patched only while the receiver is
        constructed. A new adapter is created on every call, because a
        ``return_value`` given in a decorator is built once and would be
        shared by every receiver. The adapter is also cleared when the test
        finishes, so queued requests cannot leak into other tests.
        """
        with patch.object(walkoff.cache, 'make_cache',
                          return_value=MockRedisCacheAdapter()):
            receiver = WorkflowReceiver(self.key, self.server_key,
                                        walkoff.config.Config.CACHE)
        self.addCleanup(receiver.cache.clear)
        return receiver

    def test_shutdown(self):
        receiver = self.get_receiver()
        with patch.object(receiver.cache, 'shutdown') as mock_shutdown:
            receiver.shutdown()
            self.assertTrue(receiver.exit)
            mock_shutdown.assert_called_once()

    def test_receive_workflow_no_message(self):
        receiver = self.get_receiver()
        workflow_generator = receiver.receive_workflows()
        workflow = next(workflow_generator)
        self.assertIsNone(workflow)

    def _receive_one(self, message):
        """Encrypt and enqueue *message*, then return the first item yielded."""
        receiver = self.get_receiver()
        encrypted_message = self.box.encrypt(message.SerializeToString())
        # The generator is created before the push; its body only starts
        # running on the first next() call below.
        workflow_generator = receiver.receive_workflows()
        receiver.cache.lpush('request_queue', encrypted_message)
        return next(workflow_generator)

    def check_workflow_message(self, message, expected):
        """Assert the receiver decodes *message* into the tuple *expected*."""
        workflow = self._receive_one(message)
        self.assertTupleEqual(workflow, expected)

    def test_receive_workflow_basic_workflow(self):
        workflow_id = str(uuid4())
        execution_id = str(uuid4())
        message = ExecuteWorkflowMessage()
        message.workflow_id = workflow_id
        message.workflow_execution_id = execution_id
        message.resume = True
        self.check_workflow_message(message,
                                    (workflow_id, execution_id, '', [], True))

    def test_receive_workflow_with_start(self):
        workflow_id = str(uuid4())
        execution_id = str(uuid4())
        start = str(uuid4())
        message = ExecuteWorkflowMessage()
        message.workflow_id = workflow_id
        message.workflow_execution_id = execution_id
        message.resume = True
        message.start = start
        self.check_workflow_message(
            message, (workflow_id, execution_id, start, [], True))

    def test_receive_workflow_with_arguments(self):
        workflow_id = str(uuid4())
        execution_id = str(uuid4())
        start = str(uuid4())
        ref = str(uuid4())
        arguments = [{
            'name': 'arg1',
            'value': 42
        }, {
            'name': 'arg2',
            'reference': ref,
            'selection': ['a', 1]
        }]
        message = ExecuteWorkflowMessage()
        message.workflow_id = workflow_id
        message.workflow_execution_id = execution_id
        message.resume = True
        message.start = start
        arg = message.arguments.add()
        arg.name = arguments[0]['name']
        arg.value = str(arguments[0]['value'])
        arg = message.arguments.add()
        arg.name = arguments[1]['name']
        arg.reference = arguments[1]['reference']
        arg.selection = str(arguments[1]['selection'])

        workflow = self._receive_one(message)
        # Index 3 of the yielded tuple holds the workflow arguments.
        workflow_arguments = workflow[3]
        self.assertEqual(workflow_arguments[0].name, arguments[0]['name'])
        self.assertEqual(workflow_arguments[0].value,
                         str(arguments[0]['value']))
        self.assertEqual(workflow_arguments[1].name, arguments[1]['name'])
        self.assertEqual(workflow_arguments[1].reference, ref)
        self.assertEqual(workflow_arguments[1].selection,
                         str(arguments[1]['selection']))

    def test_receive_workflow_exit(self):
        receiver = self.get_receiver()
        workflow_generator = receiver.receive_workflows()
        receiver.exit = True
        with self.assertRaises(StopIteration):
            next(workflow_generator)
예제 #4
0
 def setUp(self):
     """Create a fresh mock cache and an SseStream publishing on 'channel1'."""
     backing_cache = MockRedisCacheAdapter()
     self.cache = backing_cache
     self.channel = 'channel1'
     self.stream = SseStream(self.channel, backing_cache)
 def setUpClass(cls):
     """Build class-wide fixtures: two subscriptions, a mock cache, a controller, the DBs."""
     # NOTE(review): no @classmethod decorator is visible in this excerpt;
     # confirm it is present in the original test class.
     initialize_test_config()
     event_groups = (['a', 'b', 'c'], ['b'])
     cls.subscriptions = [Subscription(str(uuid4()), events) for events in event_groups]
     cls.cache = MockRedisCacheAdapter()
     cls.controller = WorkflowExecutionController(cls.cache)
     setup_dbs()
예제 #6
0
class TestConsoleStream(ServerTestCase):
    """Tests for the console log SSE stream, its callback and its endpoint."""

    def setUp(self):
        self.cache = MockRedisCacheAdapter()
        console_stream.cache = self.cache

    def tearDown(self):
        self.cache.clear()

    def test_format_console_data(self):
        sender = {'name': 'workflow1', 'execution_id': 'abc-def-ghi'}
        data = {'app_name': 'App1', 'action_name': 'action1', 'level': logging.WARN, 'message': 'some_message'}
        expected = copy(data)
        expected['workflow'] = 'workflow1'
        expected['level'] = logging.getLevelName(logging.WARN)
        self.assertEqual(format_console_data(sender, data=data), expected)

    @patch.object(console_stream, 'publish')
    def test_console_log_callback(self, mock_publish):
        sender = {'name': 'workflow1', 'execution_id': 'abc-def-ghi'}
        data = {'app_name': 'App1', 'action_name': 'action1', 'level': 'WARN', 'message': 'some_message'}
        console_log_callback(sender, data=data)
        expected = format_console_data(sender, data=data)
        mock_publish.assert_called_once_with(expected, event='log', subchannels=sender['execution_id'])

    def _get_access_token(self):
        """Log in through /api/auth and return the issued access token."""
        post = self.test_client.post('/api/auth', content_type="application/json",
                                     data=json.dumps(dict(username='******', password='******')), follow_redirects=True)
        return json.loads(post.get_data(as_text=True))['access_token']

    def call_stream(self, execution_id=None):
        """GET the console log stream with a valid token and an optional execution id."""
        url = '/api/streams/console/log?access_token={}'.format(self._get_access_token())
        if execution_id:
            url += '&workflow_execution_id={}'.format(execution_id)
        return self.test_client.get(url)

    @patch.object(console_stream, 'stream')
    def test_stream_endpoint(self, mock_stream):
        mock_stream.return_value = Response('something', status=SUCCESS)
        execution_id = str(uuid4())
        response = self.call_stream(execution_id=execution_id)
        mock_stream.assert_called_once_with(subchannel=execution_id)
        self.assertEqual(response.status_code, SUCCESS)

    @patch.object(console_stream, 'stream')
    def test_stream_endpoint_invalid_uuid(self, mock_stream):
        mock_stream.return_value = Response('something', status=SUCCESS)
        response = self.call_stream(execution_id='invalid')
        mock_stream.assert_not_called()
        self.assertEqual(response.status_code, BAD_REQUEST)

    @patch.object(console_stream, 'stream')
    def test_stream_endpoint_no_execution_id(self, mock_stream):
        mock_stream.return_value = Response('something', status=SUCCESS)
        response = self.call_stream()
        mock_stream.assert_not_called()
        self.assertEqual(response.status_code, BAD_REQUEST)

    @patch.object(console_stream, 'stream')
    def test_stream_endpoint_no_key(self, mock_stream):
        """An invalid access token is rejected with 422 and the stream is never opened.

        Previously named ``check_stream_endpoint_no_key``. That name lacks the
        ``test`` prefix, so unittest never collected or ran this check.
        """
        mock_stream.return_value = Response('something', status=SUCCESS)
        response = self.test_client.get('/api/streams/console/log?access_token=invalid')
        mock_stream.assert_not_called()
        self.assertEqual(response.status_code, 422)

    # Old name kept as an alias for backward compatibility. It does not start
    # with ``test``, so the loader will not run the check a second time.
    check_stream_endpoint_no_key = test_stream_endpoint_no_key


class TestWorkflowResultsStream(ServerTestCase):
    """Tests for the workflow/action SSE streams, their callbacks and endpoints."""

    def setUp(self):
        self.cache = MockRedisCacheAdapter()
        workflow_stream.cache = self.cache
        action_stream.cache = self.cache

    def tearDown(self):
        self.cache.clear()
        for status in self.app.running_context.execution_db.session.query(WorkflowStatus).all():
            self.app.running_context.execution_db.session.delete(status)
        self.app.running_context.execution_db.session.commit()

    def assert_and_strip_timestamp(self, data, field='timestamp'):
        """Pop *field* from *data* in place and assert it was present and not None."""
        timestamp = data.pop(field, None)
        self.assertIsNotNone(timestamp)

    @staticmethod
    def get_sample_action_sender():
        argument_id = str(uuid4())
        action_id = str(uuid4())
        action_execution_id = str(uuid4())
        arguments = [{'name': 'a', 'value': '42'},
                     {'name': 'b', 'reference': argument_id, 'selection': json.dumps(['a', '1'])}]
        return {
            'action_name': 'some_action_name',
            'app_name': 'HelloWorld',
            'id': action_id,
            'name': 'my_name',
            'execution_id': action_execution_id,
            'arguments': arguments
        }

    @staticmethod
    def get_action_kwargs(with_result=False):
        workflow_id = str(uuid4())
        ret = {'workflow': {'execution_id': workflow_id}}
        if with_result:
            ret['data'] = {'result': 'some result'}
        return ret

    @staticmethod
    def _expected_action_data(sender, workflow_id, status, **extra):
        """Build the expected ``format_action_data*`` output (without timestamp).

        Works on a deep copy of *sender*, so building the expectation never
        mutates the sender dict that was handed to the function under test.
        ``extra`` keys (e.g. ``result``) are merged in last.
        """
        expected = deepcopy(sender)
        expected['action_id'] = expected.pop('id')
        expected['workflow_execution_id'] = workflow_id
        expected['status'] = status.name
        expected.update(extra)
        return expected

    def test_format_action_data(self):
        workflow_id = str(uuid4())
        kwargs = {'data': {'workflow': {'execution_id': workflow_id}}}
        sender = self.get_sample_action_sender()
        status = ActionStatusEnum.executing
        result = format_action_data(sender, kwargs, status)
        expected = self._expected_action_data(sender, workflow_id, status)
        self.assert_and_strip_timestamp(result)
        self.assertDictEqual(result, expected)

    def test_format_action_data_with_results(self):
        workflow_id = str(uuid4())
        kwargs = {'data': {'workflow': {'execution_id': workflow_id},
                           'data': {'result': 'some result'}}}
        sender = self.get_sample_action_sender()
        status = ActionStatusEnum.executing
        result = format_action_data_with_results(sender, kwargs, status)
        expected = self._expected_action_data(sender, workflow_id, status, result='some result')
        self.assert_and_strip_timestamp(result)
        self.assertDictEqual(result, expected)

    def test_format_action_data_with_long_results(self):
        size_limit = 128
        workflow_id = str(uuid4())
        kwargs = {'data': {'workflow': {'execution_id': workflow_id},
                           'data': {'result': 'x' * 1024 * 2 * size_limit}}}  # should exceed limit
        sender = self.get_sample_action_sender()
        status = ActionStatusEnum.executing
        # Override the limit only for this call. patch.dict restores the previous
        # app config afterwards, so the smaller limit does not leak into other tests.
        with patch.dict(self.app.config, {'MAX_STREAM_RESULTS_SIZE_KB': size_limit}):
            result = format_action_data_with_results(sender, kwargs, status)
        expected = self._expected_action_data(
            sender, workflow_id, status, result={'truncated': 'x' * 1024 * size_limit})
        self.assert_and_strip_timestamp(result)
        self.assertDictEqual(result, expected)

    def check_action_callback(self, callback, status, event, mock_publish, mock_summary, with_result=False):
        """Run *callback* and check both the full and the summary stream publishes."""
        sender = self.get_sample_action_sender()
        kwargs = self.get_action_kwargs(with_result=with_result)
        if not with_result:
            expected = format_action_data(deepcopy(sender), {'data': kwargs}, status)
        else:
            expected = format_action_data_with_results(deepcopy(sender), {'data': kwargs}, status)
        summary = {key: expected[key] for key in action_summary_keys}
        callback(sender, data=kwargs)
        for result, mocked in zip((expected, summary), (mock_publish, mock_summary)):
            # Timestamps differ between the two formatting calls, so drop them
            # from both sides before comparing.
            result.pop('timestamp')
            mocked.assert_called_once()
            mocked.call_args[0][0].pop('timestamp')
            mocked.assert_called_with(result, event=event, subchannels=(kwargs['workflow']['execution_id'], 'all'))

    @patch.object(action_summary_stream, 'publish')
    @patch.object(action_stream, 'publish')
    def test_action_started_callback(self, mock_publish, mock_summary):
        self.check_action_callback(
            action_started_callback,
            ActionStatusEnum.executing,
            'started',
            mock_publish,
            mock_summary)

    @patch.object(action_summary_stream, 'publish')
    @patch.object(action_stream, 'publish')
    def test_action_ended_callback(self, mock_publish, mock_summary):
        self.check_action_callback(
            action_ended_callback,
            ActionStatusEnum.success,
            'success',
            mock_publish,
            mock_summary,
            with_result=True)

    @patch.object(action_summary_stream, 'publish')
    @patch.object(action_stream, 'publish')
    def test_action_error_callback(self, mock_publish, mock_summary):
        self.check_action_callback(
            action_error_callback,
            ActionStatusEnum.failure,
            'failure',
            mock_publish,
            mock_summary,
            with_result=True)

    @patch.object(action_summary_stream, 'publish')
    @patch.object(action_stream, 'publish')
    def test_action_args_invalid_callback(self, mock_publish, mock_summary):
        # NOTE(review): this test exercises action_error_callback, exactly like
        # test_action_error_callback above. That looks like a copy-paste; confirm
        # whether a dedicated args-invalid callback should be tested here.
        self.check_action_callback(
            action_error_callback,
            ActionStatusEnum.failure,
            'failure',
            mock_publish,
            mock_summary,
            with_result=True)

    @patch.object(action_summary_stream, 'publish')
    @patch.object(action_stream, 'publish')
    def test_trigger_waiting_data_action_callback(self, mock_publish, mock_summary):
        self.check_action_callback(
            trigger_awaiting_data_action_callback,
            ActionStatusEnum.awaiting_data,
            'awaiting_data',
            mock_publish,
            mock_summary
        )

    @staticmethod
    def get_workflow_sender(execution_id=None):
        execution_id = execution_id or str(uuid4())
        workflow_id = str(uuid4())
        return {'execution_id': execution_id, 'id': workflow_id, 'name': 'workflow1'}

    def test_format_workflow_result(self):
        sender = self.get_workflow_sender()
        result = format_workflow_result(sender, WorkflowStatusEnum.pending)
        self.assert_and_strip_timestamp(result)
        # Build the expectation on a copy rather than mutating the sender.
        expected = deepcopy(sender)
        expected['workflow_id'] = expected.pop('id')
        expected['status'] = WorkflowStatusEnum.pending.name
        self.assertDictEqual(result, expected)

    def get_workflow_status(self, workflow_execution_id, status):
        """Persist a WorkflowStatus holding one ActionStatus.

        Returns ``(expected_formatted_result, workflow_status)``.
        """
        workflow_id = uuid4()
        workflow_status = WorkflowStatus(workflow_execution_id, workflow_id, 'workflow1')
        action_execution_id = uuid4()
        action_id = uuid4()
        self.app.running_context.execution_db.session.add(workflow_status)
        action_status = ActionStatus(action_execution_id, action_id, 'my action', 'the_app', 'the_action')
        self.app.running_context.execution_db.session.add(action_status)
        workflow_status.add_action_status(action_status)
        expected = {
            'execution_id': str(workflow_execution_id),
            'workflow_id': str(workflow_id),
            'name': 'workflow1',
            'status': status.name,
            'current_action': action_status.as_json(summary=True)}
        return expected, workflow_status

    def test_format_workflow_result_with_current_step(self):
        workflow_execution_id = uuid4()
        expected, _ = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.running)

        result = format_workflow_result_with_current_step(workflow_execution_id, WorkflowStatusEnum.running)
        self.assert_and_strip_timestamp(result)
        self.assertDictEqual(result, expected)

    def test_format_workflow_result_with_current_step_mismatched_status(self):
        workflow_execution_id = uuid4()
        expected, status = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.running)
        status.paused()
        result = format_workflow_result_with_current_step(workflow_execution_id, WorkflowStatusEnum.running)
        self.assert_and_strip_timestamp(result)
        self.assertDictEqual(result, expected)

    def test_format_workflow_result_with_current_step_no_result_found(self):
        workflow_execution_id = uuid4()
        expected = {'execution_id': str(workflow_execution_id), 'status': WorkflowStatusEnum.paused.name}
        result = format_workflow_result_with_current_step(workflow_execution_id, WorkflowStatusEnum.paused)
        self.assert_and_strip_timestamp(result)
        self.assertDictEqual(result, expected)

    def check_workflow_callback(self, callback, sender, status, event, mock_publish, expected=None, **kwargs):
        """Run *callback* and assert a single publish of *expected* (timestamp stripped).

        When *expected* is omitted it is derived from *sender* via format_workflow_result.
        """
        if expected is None:
            expected = format_workflow_result(deepcopy(sender), status)
            expected.pop('timestamp')
        callback(sender, **kwargs)
        mock_publish.assert_called_once()
        self.assert_and_strip_timestamp(mock_publish.call_args[0][0])
        mock_publish.assert_called_with(expected, event=event, subchannels=(expected['execution_id'], 'all'))

    @patch.object(workflow_stream, 'publish')
    def test_workflow_pending_callback(self, mock_publish):
        sender = self.get_workflow_sender()
        self.check_workflow_callback(
            workflow_pending_callback,
            sender,
            WorkflowStatusEnum.pending,
            'queued',
            mock_publish)

    @patch.object(workflow_stream, 'publish')
    def test_workflow_started_callback(self, mock_publish):
        sender = self.get_workflow_sender()
        self.check_workflow_callback(
            workflow_started_callback,
            sender,
            WorkflowStatusEnum.running,
            'started',
            mock_publish)

    @patch.object(workflow_stream, 'publish')
    def test_workflow_paused_callback(self, mock_publish):
        workflow_execution_id = uuid4()
        sender = self.get_workflow_sender(execution_id=str(workflow_execution_id))
        expected, status = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.paused)
        self.check_workflow_callback(
            workflow_paused_callback,
            sender,
            WorkflowStatusEnum.paused,
            'paused',
            mock_publish,
            expected=expected)

    @patch.object(workflow_stream, 'publish')
    def test_workflow_resumed_callback(self, mock_publish):
        workflow_execution_id = uuid4()

        class MockWorkflowSender(object):
            def get_execution_id(self):
                return workflow_execution_id

        sender = MockWorkflowSender()
        expected, status = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.running)
        self.check_workflow_callback(
            workflow_resumed_callback,
            sender,
            WorkflowStatusEnum.running,
            'resumed',
            mock_publish,
            expected=expected,
            data={"execution_id": workflow_execution_id})

    @patch.object(workflow_stream, 'publish')
    def test_trigger_awaiting_data_workflow_callback(self, mock_publish):
        workflow_execution_id = uuid4()
        expected, status = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.awaiting_data)
        self.check_workflow_callback(
            trigger_awaiting_data_workflow_callback,
            None,
            WorkflowStatusEnum.awaiting_data,
            'awaiting_data',
            mock_publish,
            expected=expected,
            data={'workflow': {'execution_id': str(workflow_execution_id)}})

    @patch.object(workflow_stream, 'publish')
    def test_trigger_action_taken_callback(self, mock_publish):
        workflow_execution_id = uuid4()
        expected, status = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.pending)
        self.check_workflow_callback(
            trigger_action_taken_callback,
            None,
            WorkflowStatusEnum.pending,
            'triggered',
            mock_publish,
            expected=expected,
            data={'workflow_execution_id': str(workflow_execution_id)})

    @patch.object(workflow_stream, 'publish')
    def test_workflow_aborted_callback(self, mock_publish):
        sender = self.get_workflow_sender()
        self.check_workflow_callback(
            workflow_aborted_callback,
            sender,
            WorkflowStatusEnum.aborted,
            'aborted',
            mock_publish)

    @patch.object(workflow_stream, 'publish')
    def test_workflow_shutdown_callback(self, mock_publish):
        sender = self.get_workflow_sender()
        self.check_workflow_callback(
            workflow_shutdown_callback,
            sender,
            WorkflowStatusEnum.completed,
            'completed',
            mock_publish)

    def check_stream_endpoint(self, endpoint, mock_stream, execution_id=None, summary=False):
        """GET a workflowqueue stream endpoint and check which subchannel was opened.

        A missing execution id maps to the ``'all'`` subchannel. The literal id
        ``'invalid'`` must be rejected with BAD_REQUEST without opening a stream.
        """
        mock_stream.return_value = Response('something', status=SUCCESS)
        post = self.test_client.post('/api/auth', content_type="application/json",
                                     data=json.dumps(dict(username='******', password='******')), follow_redirects=True)
        key = json.loads(post.get_data(as_text=True))['access_token']
        url = '/api/streams/workflowqueue/{}?access_token={}'.format(endpoint, key)
        if execution_id:
            url += '&workflow_execution_id={}'.format(execution_id)
        if summary:
            url += '&summary=true'
        response = self.test_client.get(url)
        if execution_id is None:
            execution_id = 'all'
        if execution_id != 'invalid':
            mock_stream.assert_called_once_with(subchannel=execution_id)
            self.assertEqual(response.status_code, SUCCESS)
        else:
            mock_stream.assert_not_called()
            self.assertEqual(response.status_code, BAD_REQUEST)

    def check_stream_endpoint_no_key(self, endpoint, mock_stream):
        """An invalid access token must be rejected with 422 before any stream opens."""
        mock_stream.return_value = Response('something', status=SUCCESS)
        response = self.test_client.get('/api/streams/workflowqueue/{}?access_token=invalid'.format(endpoint))
        mock_stream.assert_not_called()
        self.assertEqual(response.status_code, 422)

    @patch.object(action_stream, 'stream')
    def test_action_stream_endpoint(self, mock_stream):
        self.check_stream_endpoint('actions', mock_stream)

    @patch.object(action_stream, 'stream')
    def test_action_stream_endpoint_with_execution_id(self, mock_stream):
        execution_id = str(uuid4())
        self.check_stream_endpoint('actions', mock_stream, execution_id=execution_id)

    @patch.object(action_stream, 'stream')
    def test_action_stream_endpoint_with_invalid_execution_id(self, mock_stream):
        self.check_stream_endpoint('actions', mock_stream, execution_id='invalid')

    @patch.object(action_summary_stream, 'stream')
    def test_action_stream_endpoint_with_summary(self, mock_stream):
        self.check_stream_endpoint('actions', mock_stream, summary=True)

    @patch.object(action_summary_stream, 'stream')
    def test_action_stream_endpoint_with_execution_id_with_summary(self, mock_stream):
        execution_id = str(uuid4())
        self.check_stream_endpoint('actions', mock_stream, execution_id=execution_id, summary=True)

    @patch.object(workflow_stream, 'stream')
    def test_workflow_stream_endpoint(self, mock_stream):
        self.check_stream_endpoint('workflow_status', mock_stream)

    @patch.object(workflow_stream, 'stream')
    def test_workflow_stream_endpoint_with_execution_id(self, mock_stream):
        execution_id = str(uuid4())
        self.check_stream_endpoint('workflow_status', mock_stream, execution_id=execution_id)

    @patch.object(workflow_stream, 'stream')
    def test_workflow_stream_endpoint_with_invalid_execution_id(self, mock_stream):
        self.check_stream_endpoint('workflow_status', mock_stream, execution_id='invalid')

    @patch.object(action_stream, 'stream')
    def test_action_stream_endpoint_invalid_key(self, mock_stream):
        self.check_stream_endpoint_no_key('actions', mock_stream)

    @patch.object(workflow_stream, 'stream')
    def test_workflow_stream_endpoint_invalid_key(self, mock_stream):
        self.check_stream_endpoint_no_key('workflow_status', mock_stream)
예제 #8
0
 def setUp(self):
     """Build a FilteredSseStream on 'channel1' backed by a fresh mock cache."""
     self.channel = 'channel1'
     self.cache = MockRedisCacheAdapter()
     self.stream = FilteredSseStream(self.channel, self.cache)
 def setUpClass(cls):
     """Create one MockRedisCacheAdapter stored on the class as ``cls.redis_cache``."""
     # NOTE(review): no @classmethod decorator is visible in this excerpt;
     # confirm it is present in the original test class.
     cls.redis_cache = MockRedisCacheAdapter()
예제 #10
0
 def setUpClass(cls):
     """Initialise the test config and create class-wide fixtures.

     Stores the return value of ``execution_db_help.setup_dbs()`` as
     ``cls.execution_db`` and a MockRedisCacheAdapter as ``cls.cache``.
     """
     # NOTE(review): no @classmethod decorator is visible in this excerpt;
     # confirm it is present in the original test class.
     initialize_test_config()
     cls.execution_db = execution_db_help.setup_dbs()
     cls.cache = MockRedisCacheAdapter()
예제 #11
0
 def setUp(self):
     """Install a fresh mock cache on both the test case and ``sse_stream``."""
     cache = MockRedisCacheAdapter()
     self.cache = cache
     sse_stream.cache = cache
예제 #12
0
class TestNotificationStream(ServerTestCase):
    """Tests for the message-notification SSE stream, its callbacks and its endpoint."""

    def setUp(self):
        cache = MockRedisCacheAdapter()
        self.cache = cache
        sse_stream.cache = cache

    def tearDown(self):
        self.cache.clear()

    @staticmethod
    def get_standard_message_and_user():
        """Return a message addressed to users 1 and 2, plus a third user (id 3)."""
        recipients = [MockUser(1, 'uname'), MockUser(2, 'admin2')]
        message = MockMessage(1, 'sub', recipients, datetime.utcnow(), False)
        return message, MockUser(3, 'uname2')

    @staticmethod
    def _format_user_dict(user):
        # Callbacks receive the acting user nested as data -> user.
        nested = {'user': user}
        return {'data': nested}

    def assert_timestamp_is_not_none(self, formatted):
        """Pop 'timestamp' from *formatted* in place and assert it was set."""
        self.assertIsNotNone(formatted.pop('timestamp', None))

    def test_format_read_responded_data(self):
        message, user = self.get_standard_message_and_user()
        formatted = format_read_responded_data(message, user)
        self.assert_timestamp_is_not_none(formatted)
        self.assertDictEqual(formatted, {'id': 1, 'username': '******'})

    @patch.object(sse_stream, 'publish')
    def test_message_created_callback(self, mock_publish):
        message, user = self.get_standard_message_and_user()
        result, ids = message_created_callback(message, **self._format_user_dict(user))
        self.assertSetEqual(ids, {1, 2})
        self.assertDictEqual(result, {
            'id': message.id,
            'subject': message.subject,
            'created_at': message.created_at.isoformat(),
            'is_read': False,
            'awaiting_response': message.requires_response})
        mock_publish.assert_called_once_with(result, subchannels=ids, event=NotificationSseEvent.created.name)

    def _check_user_action_callback(self, callback, event, mock_publish):
        """Shared body for the read/responded callback tests."""
        message, user = self.get_standard_message_and_user()
        result, ids = callback(message, **self._format_user_dict(user))
        self.assertSetEqual(ids, {1, 2})
        # Check the publish before the timestamp is stripped from the shared dict.
        mock_publish.assert_called_once_with(result, subchannels=ids, event=event.name)
        self.assert_timestamp_is_not_none(result)
        self.assertDictEqual(result, {'id': message.id, 'username': user.username})

    @patch.object(sse_stream, 'publish')
    def test_message_read_callback(self, mock_publish):
        self._check_user_action_callback(message_read_callback, NotificationSseEvent.read, mock_publish)

    @patch.object(sse_stream, 'publish')
    def test_message_responded_callback(self, mock_publish):
        self._check_user_action_callback(message_responded_callback, NotificationSseEvent.responded, mock_publish)

    @patch.object(sse_stream, 'stream')
    def test_notifications_stream_endpoint(self, mock_stream):
        mock_stream.return_value = Response('something', status=SUCCESS)
        credentials = json.dumps(dict(username='******', password='******'))
        login = self.test_client.post('/api/auth', content_type="application/json",
                                      data=credentials, follow_redirects=True)
        token = json.loads(login.get_data(as_text=True))['access_token']
        stream_url = '/api/streams/messages/notifications?access_token={}'.format(token)
        response = self.test_client.get(stream_url)
        mock_stream.assert_called_once_with(subchannel=1)
        self.assertEqual(response.status_code, SUCCESS)

    @patch.object(sse_stream, 'stream')
    def test_notifications_stream_endpoint_no_key(self, mock_stream):
        mock_stream.return_value = Response('something', status=SUCCESS)
        stream_url = '/api/streams/messages/notifications?access_token=invalid'
        response = self.test_client.get(stream_url)
        mock_stream.assert_not_called()
        self.assertEqual(response.status_code, 422)
 def setUpClass(cls):
     """Initialise the test config and create class-wide fixtures.

     Stores a MockRedisCacheAdapter on ``cls.cache`` and a
     ZmqWorkflowCommunicationSender on ``cls.controller``, then calls
     ``setup_dbs()``.
     """
     # NOTE(review): no @classmethod decorator is visible in this excerpt, and
     # cls.cache is not passed to the sender here; confirm both against the
     # original test class.
     initialize_test_config()
     cls.cache = MockRedisCacheAdapter()
     cls.controller = ZmqWorkflowCommunicationSender()
     setup_dbs()
예제 #14
0
class TestRedisCacheAdapter(TestCase):
    """Behavioral tests for the cache adapter API, run against MockRedisCacheAdapter.

    Scalars written through the adapter are asserted to come back as strings
    (e.g. ``1`` is read back as ``'1'``).
    """

    def setUp(self):
        self.cache = MockRedisCacheAdapter()

    def tearDown(self):
        # Empty the cache and shut the adapter down after every test.
        self.cache.clear()
        self.cache.shutdown()

    @staticmethod
    def _next_data_message(sub):
        """Return the next pub/sub message of ``sub``, skipping a subscribe confirmation.

        A Redis pub/sub connection first delivers a confirmation message whose
        ``data`` is the number of subscribed channels (1 here). Reading a single
        message right after ``subscribe`` can therefore return that confirmation
        instead of the published payload. If the first message's data is 1, the
        following message is returned instead.
        """
        result = sub._pubsub.get_message()
        if result['data'] == 1:
            result = sub._pubsub.get_message()
        return result

    def test_set_get(self):
        self.assertTrue(self.cache.set('alice', 'something'))
        self.assertEqual(self.cache.get('alice'), 'something')
        self.assertTrue(self.cache.set('count', 1))
        self.assertEqual(self.cache.get('count'), '1')
        self.assertTrue(self.cache.set('count', 2))
        self.assertEqual(self.cache.get('count'), '2')

    def test_get_key_dne(self):
        self.assertIsNone(self.cache.get('invalid_key'))

    def test_add(self):
        self.assertTrue(self.cache.add('test', 123))
        self.assertEqual(self.cache.get('test'), '123')
        # add() on an existing key must fail and leave the stored value unchanged.
        self.assertFalse(self.cache.add('test', 456))
        self.assertEqual(self.cache.get('test'), '123')

    def test_incr(self):
        self.cache.set('count', 1)
        self.assertEqual(self.cache.incr('count'), 2)
        self.assertEqual(self.cache.get('count'), '2')

    def test_incr_multiple(self):
        self.cache.set('uid', 3)
        self.assertEqual(self.cache.incr('uid', amount=10), 13)
        self.assertEqual(self.cache.get('uid'), '13')

    def test_incr_key_dne(self):
        self.assertEqual(self.cache.incr('count'), 1)
        self.assertEqual(self.cache.get('count'), '1')

    def test_incr_multiple_key_dne(self):
        self.assertEqual(self.cache.incr('workflows', amount=10), 10)
        self.assertEqual(self.cache.get('workflows'), '10')

    def test_decr(self):
        self.cache.set('count', 0)
        self.assertEqual(self.cache.decr('count'), -1)
        self.assertEqual(self.cache.get('count'), '-1')

    def test_decr_multiple(self):
        self.cache.set('uid', 3)
        self.assertEqual(self.cache.decr('uid', amount=10), -7)
        self.assertEqual(self.cache.get('uid'), '-7')

    def test_decr_key_dne(self):
        self.assertEqual(self.cache.decr('count'), -1)
        self.assertEqual(self.cache.get('count'), '-1')

    def test_decr_multiple_key_dne(self):
        self.assertEqual(self.cache.decr('workflows', amount=10), -10)
        self.assertEqual(self.cache.get('workflows'), '-10')

    def test_r_push_pop_single_value(self):
        self.cache.rpush('queue', 10)
        self.assertEqual(self.cache.rpop('queue'), '10')

    def test_r_push_pop_multiple_values(self):
        self.cache.rpush('big', 10, 11, 12)
        self.assertEqual(self.cache.rpop('big'), '12')

    def test_l_push_pop_single_value(self):
        self.cache.lpush('queue', 10)
        self.assertEqual(self.cache.lpop('queue'), '10')

    def test_l_push_pop_multiple_values(self):
        self.cache.rpush('big', 10, 11, 12)
        self.assertEqual(self.cache.lpop('big'), '10')
        self.assertEqual(self.cache.rpop('big'), '12')

    def test_subscribe(self):
        sub = self.cache.subscribe('channel1')
        self.assertEqual(sub.channel, 'channel1')

    def test_publish(self):
        sub = self.cache.subscribe('channel_a')
        self.cache.publish('channel_a', '42')
        result = self._next_data_message(sub)
        self.assertEqual(result['data'], b'42')

    def test_unsubscribe(self):
        sub = self.cache.subscribe('channel_a')
        self.cache.unsubscribe('channel_a')
        result = self._next_data_message(sub)
        self.assertEqual(result['data'], unsubscribe_message)
예제 #15
0
class TestWorkflowResultsStream(ServerTestCase):
    """Tests for workflow/action result formatting, the SSE publish callbacks and stream endpoints.

    ``setUp`` installs a MockRedisCacheAdapter as the cache of both
    ``workflow_stream`` and ``action_stream``.
    """

    def setUp(self):
        self.cache = MockRedisCacheAdapter()
        workflow_stream.cache = self.cache
        action_stream.cache = self.cache

    def tearDown(self):
        # Clear the mock cache and delete every WorkflowStatus row left in the execution db.
        self.cache.clear()
        for status in self.app.running_context.execution_db.session.query(WorkflowStatus).all():
            self.app.running_context.execution_db.session.delete(status)
        self.app.running_context.execution_db.session.commit()

    def assert_and_strip_timestamp(self, data, field='timestamp'):
        """Pop ``field`` from ``data`` in place and assert it was present and not None.

        Timestamps depend on call time, so they are only checked for presence and
        then removed so that the rest of the payload can be compared exactly.
        """
        timestamp = data.pop(field, None)
        self.assertIsNotNone(timestamp)

    @staticmethod
    def get_sample_action_sender():
        """Return a new action-sender dict with random ids and two sample arguments."""
        argument_id = str(uuid4())
        action_id = str(uuid4())
        action_execution_id = str(uuid4())
        arguments = [{'name': 'a', 'value': '42'},
                     {'name': 'b', 'reference': argument_id, 'selection': json.dumps(['a', '1'])}]
        return {
            'action_name': 'some_action_name',
            'app_name': 'HelloWorld',
            'id': action_id,
            'name': 'my_name',
            'execution_id': action_execution_id,
            'arguments': arguments
        }

    @staticmethod
    def get_action_kwargs(with_result=False):
        """Return callback data that references a random workflow execution id.

        When ``with_result`` is true, ``{'result': 'some result'}`` is added under ``'data'``.
        """
        workflow_id = str(uuid4())
        ret = {'workflow': {'execution_id': workflow_id}}
        if with_result:
            ret['data'] = {'result': 'some result'}
        return ret

    def test_format_action_data(self):
        workflow_id = str(uuid4())
        kwargs = {'data': {'workflow': {'execution_id': workflow_id}}}
        sender = self.get_sample_action_sender()
        # Snapshot the sender before formatting. Using ``sender`` itself as the
        # expectation would make the comparison vacuous if the formatter
        # returned (or mutated) that same dict.
        expected = deepcopy(sender)
        status = ActionStatusEnum.executing
        result = format_action_data(sender, kwargs, status)
        expected['action_id'] = expected.pop('id')
        expected['workflow_execution_id'] = workflow_id
        expected['status'] = status.name
        self.assert_and_strip_timestamp(result)
        self.assertDictEqual(result, expected)

    def test_format_action_data_with_results(self):
        workflow_id = str(uuid4())
        kwargs = {'data': {'workflow': {'execution_id': workflow_id},
                           'data': {'result': 'some result'}}}
        sender = self.get_sample_action_sender()
        # Independent snapshot so the expectation cannot alias the formatter's output.
        expected = deepcopy(sender)
        status = ActionStatusEnum.executing
        result = format_action_data_with_results(sender, kwargs, status)
        expected['action_id'] = expected.pop('id')
        expected['workflow_execution_id'] = workflow_id
        expected['status'] = status.name
        expected['result'] = 'some result'
        self.assert_and_strip_timestamp(result)
        self.assertDictEqual(result, expected)

    def check_action_callback(self, callback, status, event, mock_publish, with_result=False):
        """Invoke an action callback and assert it published the formatted action data once.

        ``status`` is the expected ActionStatusEnum and ``event`` the expected SSE
        event name. ``with_result`` selects the formatter that also includes the
        result. Timestamps on both sides are checked for presence and then ignored.
        """
        sender = self.get_sample_action_sender()
        kwargs = self.get_action_kwargs(with_result=with_result)
        formatter = format_action_data_with_results if with_result else format_action_data
        # Format a copy so the sender handed to the callback stays untouched.
        expected = formatter(deepcopy(sender), {'data': kwargs}, status)
        self.assert_and_strip_timestamp(expected)
        callback(sender, data=kwargs)
        mock_publish.assert_called_once()
        self.assert_and_strip_timestamp(mock_publish.call_args[0][0])
        mock_publish.assert_called_with(expected, event=event)

    @patch.object(action_stream, 'publish')
    def test_action_started_callback(self, mock_publish):
        self.check_action_callback(action_started_callback, ActionStatusEnum.executing, 'started', mock_publish)

    @patch.object(action_stream, 'publish')
    def test_action_ended_callback(self, mock_publish):
        self.check_action_callback(
            action_ended_callback,
            ActionStatusEnum.success,
            'success',
            mock_publish,
            with_result=True)

    @patch.object(action_stream, 'publish')
    def test_action_error_callback(self, mock_publish):
        self.check_action_callback(
            action_error_callback,
            ActionStatusEnum.failure,
            'failure', mock_publish,
            with_result=True)

    @patch.object(action_stream, 'publish')
    def test_action_args_invalid_callback(self, mock_publish):
        self.check_action_callback(
            action_args_invalid_callback,
            ActionStatusEnum.failure,
            'failure',
            mock_publish,
            with_result=True)

    @patch.object(action_stream, 'publish')
    def test_trigger_waiting_data_action_callback(self, mock_publish):
        self.check_action_callback(
            trigger_awaiting_data_action_callback,
            ActionStatusEnum.awaiting_data,
            'awaiting_data',
            mock_publish)

    @staticmethod
    def get_workflow_sender(execution_id=None):
        """Return a workflow-sender dict; a random execution id is used when none is given."""
        execution_id = execution_id or str(uuid4())
        workflow_id = str(uuid4())
        return {'execution_id': execution_id, 'id': workflow_id, 'name': 'workflow1'}

    def test_format_workflow_result(self):
        sender = self.get_workflow_sender()
        # Independent snapshot so the expectation cannot alias the formatter's output.
        expected = deepcopy(sender)
        result = format_workflow_result(sender, WorkflowStatusEnum.pending)
        self.assert_and_strip_timestamp(result)
        expected['workflow_id'] = expected.pop('id')
        expected['status'] = WorkflowStatusEnum.pending.name
        self.assertDictEqual(result, expected)

    def get_workflow_status(self, workflow_execution_id, status):
        """Add a WorkflowStatus with one ActionStatus to the session.

        Returns ``(expected, workflow_status)``. ``expected`` is the formatted
        result, without a timestamp, that is expected for ``status``.
        """
        workflow_id = uuid4()
        workflow_status = WorkflowStatus(workflow_execution_id, workflow_id, 'workflow1')
        action_execution_id = uuid4()
        action_id = uuid4()
        self.app.running_context.execution_db.session.add(workflow_status)
        action_status = ActionStatus(action_execution_id, action_id, 'my action', 'the_app', 'the_action')
        self.app.running_context.execution_db.session.add(action_status)
        workflow_status.add_action_status(action_status)
        expected = {
            'execution_id': str(workflow_execution_id),
            'workflow_id': str(workflow_id),
            'name': 'workflow1',
            'status': status.name,
            'current_action': action_status.as_json(summary=True)}
        return expected, workflow_status

    def test_format_workflow_result_with_current_step(self):
        workflow_execution_id = uuid4()
        expected, _ = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.running)

        result = format_workflow_result_with_current_step(workflow_execution_id, WorkflowStatusEnum.running)
        self.assert_and_strip_timestamp(result)
        self.assertDictEqual(result, expected)

    def test_format_workflow_result_with_current_step_mismatched_status(self):
        workflow_execution_id = uuid4()
        expected, status = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.running)
        status.paused()
        result = format_workflow_result_with_current_step(workflow_execution_id, WorkflowStatusEnum.running)
        self.assert_and_strip_timestamp(result)
        self.assertDictEqual(result, expected)

    def test_format_workflow_result_with_current_step_no_result_found(self):
        workflow_execution_id = uuid4()
        expected = {'execution_id': str(workflow_execution_id), 'status': WorkflowStatusEnum.paused.name}
        result = format_workflow_result_with_current_step(workflow_execution_id, WorkflowStatusEnum.paused)
        self.assert_and_strip_timestamp(result)
        self.assertDictEqual(result, expected)

    def check_workflow_callback(self, callback, sender, status, event, mock_publish, expected=None, **kwargs):
        """Invoke a workflow callback and assert it published ``expected`` once under ``event``.

        When ``expected`` is not given, it is built with ``format_workflow_result``
        from a copy of ``sender``. Extra ``kwargs`` are forwarded to the callback.
        """
        if expected is None:
            expected = format_workflow_result(deepcopy(sender), status)
            self.assert_and_strip_timestamp(expected)
        callback(sender, **kwargs)
        mock_publish.assert_called_once()
        self.assert_and_strip_timestamp(mock_publish.call_args[0][0])
        mock_publish.assert_called_with(expected, event=event)

    @patch.object(workflow_stream, 'publish')
    def test_workflow_pending_callback(self, mock_publish):
        sender = self.get_workflow_sender()
        self.check_workflow_callback(
            workflow_pending_callback,
            sender,
            WorkflowStatusEnum.pending,
            'queued',
            mock_publish)

    @patch.object(workflow_stream, 'publish')
    def test_workflow_started_callback(self, mock_publish):
        sender = self.get_workflow_sender()
        self.check_workflow_callback(
            workflow_started_callback,
            sender,
            WorkflowStatusEnum.running,
            'started',
            mock_publish)

    @patch.object(workflow_stream, 'publish')
    def test_workflow_paused_callback(self, mock_publish):
        workflow_execution_id = uuid4()
        sender = self.get_workflow_sender(execution_id=str(workflow_execution_id))
        expected, _ = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.paused)
        self.check_workflow_callback(
            workflow_paused_callback,
            sender,
            WorkflowStatusEnum.paused,
            'paused',
            mock_publish,
            expected=expected)

    @patch.object(workflow_stream, 'publish')
    def test_workflow_resumed_callback(self, mock_publish):
        workflow_execution_id = uuid4()

        class MockWorkflowSender(object):
            def get_execution_id(self):
                return workflow_execution_id

        sender = MockWorkflowSender()
        expected, _ = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.running)
        self.check_workflow_callback(
            workflow_resumed_callback,
            sender,
            WorkflowStatusEnum.running,
            'resumed',
            mock_publish,
            expected=expected)

    @patch.object(workflow_stream, 'publish')
    def test_trigger_awaiting_data_workflow_callback(self, mock_publish):
        workflow_execution_id = uuid4()
        expected, _ = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.awaiting_data)
        self.check_workflow_callback(
            trigger_awaiting_data_workflow_callback,
            None,
            WorkflowStatusEnum.awaiting_data,
            'awaiting_data',
            mock_publish,
            expected=expected,
            data={'workflow': {'execution_id': str(workflow_execution_id)}})

    @patch.object(workflow_stream, 'publish')
    def test_trigger_action_taken_callback(self, mock_publish):
        workflow_execution_id = uuid4()
        expected, _ = self.get_workflow_status(workflow_execution_id, WorkflowStatusEnum.pending)
        self.check_workflow_callback(
            trigger_action_taken_callback,
            None,
            WorkflowStatusEnum.pending,
            'triggered',
            mock_publish,
            expected=expected,
            data={'workflow_execution_id': str(workflow_execution_id)})

    @patch.object(workflow_stream, 'publish')
    def test_workflow_aborted_callback(self, mock_publish):
        sender = self.get_workflow_sender()
        self.check_workflow_callback(
            workflow_aborted_callback,
            sender,
            WorkflowStatusEnum.aborted,
            'aborted',
            mock_publish)

    @patch.object(workflow_stream, 'publish')
    def test_workflow_shutdown_callback(self, mock_publish):
        sender = self.get_workflow_sender()
        self.check_workflow_callback(
            workflow_shutdown_callback,
            sender,
            WorkflowStatusEnum.completed,
            'completed',
            mock_publish)

    def check_stream_endpoint(self, endpoint, mock_stream):
        """Log in, request ``/api/streams/workflowqueue/<endpoint>`` and expect the mocked stream."""
        mock_stream.return_value = Response('something', status=SUCCESS)
        post = self.test_client.post('/api/auth', content_type="application/json",
                                     data=json.dumps(dict(username='******', password='******')), follow_redirects=True)
        key = json.loads(post.get_data(as_text=True))['access_token']
        response = self.test_client.get('/api/streams/workflowqueue/{}?access_token={}'.format(endpoint, key))
        mock_stream.assert_called_once_with()
        self.assertEqual(response.status_code, SUCCESS)

    def check_stream_endpoint_no_key(self, endpoint, mock_stream):
        """Request the stream endpoint with an invalid token and expect 422 with no stream opened."""
        mock_stream.return_value = Response('something', status=SUCCESS)
        response = self.test_client.get('/api/streams/workflowqueue/{}?access_token=invalid'.format(endpoint))
        mock_stream.assert_not_called()
        self.assertEqual(response.status_code, 422)

    @patch.object(action_stream, 'stream')
    def test_action_stream_endpoint(self, mock_stream):
        self.check_stream_endpoint('actions', mock_stream)

    @patch.object(workflow_stream, 'stream')
    def test_workflow_stream_endpoint(self, mock_stream):
        self.check_stream_endpoint('workflow_status', mock_stream)

    @patch.object(action_stream, 'stream')
    def test_action_stream_endpoint_invalid_key(self, mock_stream):
        self.check_stream_endpoint_no_key('actions', mock_stream)

    @patch.object(workflow_stream, 'stream')
    def test_workflow_stream_endpoint_invalid_key(self, mock_stream):
        self.check_stream_endpoint_no_key('workflow_status', mock_stream)
 def test_singleton(self):
     """Constructing MockRedisCacheAdapter again returns the very same instance."""
     self.assertIs(MockRedisCacheAdapter(), self.cache)
예제 #17
0
 def setUp(self):
     """Install one MockRedisCacheAdapter as the cache of both result streams."""
     mock_cache = self.cache = MockRedisCacheAdapter()
     for stream in (workflow_stream, action_stream):
         stream.cache = mock_cache
class TestRedisCacheAdapter(TestCase):
    """Exercise the cache adapter API against MockRedisCacheAdapter.

    Scalars stored through the adapter are read back as strings, so numeric
    values are compared against their string form.
    """

    def setUp(self):
        self.cache = MockRedisCacheAdapter()

    def tearDown(self):
        cache = self.cache
        cache.clear()
        cache.shutdown()

    def _next_payload(self, subscription):
        """Read the next pub/sub message, skipping a first message whose data is 1.

        A data value of 1 is the subscribe confirmation (the subscription count),
        not a published payload.
        """
        pubsub = subscription._pubsub
        message = pubsub.get_message()
        return pubsub.get_message() if message['data'] == 1 else message

    def _check_counter(self, operation, key, expected, start=None, **operation_kwargs):
        """Optionally seed ``key`` with ``start``, apply ``operation`` and verify the outcome.

        Both the value returned by ``operation`` and the value then stored under
        ``key`` (as a string) must equal ``expected``.
        """
        if start is not None:
            self.cache.set(key, start)
        self.assertEqual(operation(key, **operation_kwargs), expected)
        self.assertEqual(self.cache.get(key), str(expected))

    def test_singleton(self):
        self.assertIs(MockRedisCacheAdapter(), self.cache)

    def test_set_get(self):
        cache = self.cache
        self.assertTrue(cache.set('alice', 'something'))
        self.assertEqual(cache.get('alice'), 'something')
        for count in (1, 2):
            self.assertTrue(cache.set('count', count))
            self.assertEqual(cache.get('count'), str(count))

    def test_get_key_dne(self):
        missing = self.cache.get('invalid_key')
        self.assertIsNone(missing)

    def test_add(self):
        cache = self.cache
        self.assertTrue(cache.add('test', 123))
        self.assertEqual(cache.get('test'), '123')
        # A second add on an existing key is refused and the value is kept.
        self.assertFalse(cache.add('test', 456))
        self.assertEqual(cache.get('test'), '123')

    def test_delete(self):
        cache = self.cache
        self.assertTrue(cache.set('alice', 'something'))
        cache.delete('alice')
        self.assertIsNone(cache.get('alice'))

    def test_delete_dne(self):
        cache = self.cache
        cache.delete('alice')
        self.assertIsNone(cache.get('alice'))

    def test_incr(self):
        self._check_counter(self.cache.incr, 'count', 2, start=1)

    def test_incr_multiple(self):
        self._check_counter(self.cache.incr, 'uid', 13, start=3, amount=10)

    def test_incr_key_dne(self):
        self._check_counter(self.cache.incr, 'count', 1)

    def test_incr_multiple_key_dne(self):
        self._check_counter(self.cache.incr, 'workflows', 10, amount=10)

    def test_decr(self):
        self._check_counter(self.cache.decr, 'count', -1, start=0)

    def test_decr_multiple(self):
        self._check_counter(self.cache.decr, 'uid', -7, start=3, amount=10)

    def test_decr_key_dne(self):
        self._check_counter(self.cache.decr, 'count', -1)

    def test_decr_multiple_key_dne(self):
        self._check_counter(self.cache.decr, 'workflows', -10, amount=10)

    def test_r_push_pop_single_value(self):
        cache = self.cache
        cache.rpush('queue', 10)
        self.assertEqual(cache.rpop('queue'), '10')

    def test_r_push_pop_multiple_values(self):
        cache = self.cache
        cache.rpush('big', 10, 11, 12)
        self.assertEqual(cache.rpop('big'), '12')

    def test_l_push_pop_single_value(self):
        cache = self.cache
        cache.lpush('queue', 10)
        self.assertEqual(cache.lpop('queue'), '10')

    def test_l_push_pop_multiple_values(self):
        cache = self.cache
        cache.rpush('big', 10, 11, 12)
        # The left end yields the first pushed value, the right end the last.
        self.assertEqual(cache.lpop('big'), '10')
        self.assertEqual(cache.rpop('big'), '12')

    def test_scan_no_pattern(self):
        keys = ('a', 'b', 'c', 'd')
        for value, key in enumerate(keys):
            self.cache.set(key, value)
        self.assertSetEqual(set(self.cache.scan()), set(keys))

    def test_scan_with_pattern(self):
        keys = ('1.a', '2.a', '3.b', 'd')
        for value, key in enumerate(keys):
            self.cache.set(key, value)
        matched = set(self.cache.scan('*.a'))
        self.assertSetEqual(matched, {'1.a', '2.a'})

    def test_exists(self):
        cache, key = self.cache, 'abc'
        self.assertFalse(cache.exists(key))
        cache.set(key, 42)
        self.assertTrue(cache.exists(key))

    def test_subscribe(self):
        subscription = self.cache.subscribe('channel1')
        self.assertEqual(subscription.channel, 'channel1')

    def test_publish(self):
        subscription = self.cache.subscribe('channel_a')
        self.cache.publish('channel_a', '42')
        self.assertEqual(self._next_payload(subscription)['data'], b'42')

    def test_unsubscribe(self):
        subscription = self.cache.subscribe('channel_a')
        self.cache.unsubscribe('channel_a')
        self.assertEqual(self._next_payload(subscription)['data'], unsubscribe_message)

    def test_lock(self):
        lock = self.cache.lock('myname', timeout=4.5, sleep=0.5, blocking_timeout=1.6)
        self.assertEqual(lock.name, 'myname')
예제 #19
0
class TestSimpleFilteredSseStream(TestCase):
    """Tests for FilteredSseStream: subchannel naming, response headers and per-subchannel delivery."""

    @classmethod
    def setUpClass(cls):
        # NOTE(review): presumably gevent.monkey.patch_all, applied once before
        # any greenlets are spawned; confirm against the module imports.
        patch_all()

    def setUp(self):
        self.cache = MockRedisCacheAdapter()
        self.channel = 'channel1'
        self.stream = FilteredSseStream(self.channel, self.cache)

    def tearDown(self):
        self.cache.clear()

    @staticmethod
    def _join_and_reap(greenlet_groups, timeout=2):
        """Join each group of greenlets in order, then kill whatever is still running.

        Each group gets up to ``timeout`` seconds, matching the original
        per-group ``gevent.joinall`` calls. A greenlet that has not finished by
        then (e.g. a listener whose unsubscribe never arrived) is killed, so it
        cannot keep running into later tests.
        """
        for group in greenlet_groups:
            gevent.joinall(group, timeout=timeout)
        stragglers = [g for group in greenlet_groups for g in group if not g.dead]
        if stragglers:
            gevent.killall(stragglers, timeout=timeout)

    def test_init(self):
        self.assertEqual(self.stream.channel, self.channel)
        self.assertEqual(self.stream.cache, self.cache)

    def test_create_channel_name(self):
        self.assertEqual(self.stream.create_subchannel_name('a'),
                         '{}.a'.format(self.channel))
        self.assertEqual(self.stream.create_subchannel_name(14),
                         '{}.14'.format(self.channel))

    def assert_header_in_response(self, response, header, value):
        """Assert ``response`` has a header named exactly ``header`` whose first value is ``value``."""
        header_tuple = next(
            (header_ for header_ in response.headers if header_[0] == header),
            None)
        self.assertIsNotNone(header_tuple)
        self.assertEqual(header_tuple[1], value)

    def test_stream_default_headers(self):
        resp = self.stream.stream(subchannel='a')
        self.assert_header_in_response(resp, 'Connection', 'keep-alive')
        self.assert_header_in_response(resp, 'Cache-Control', 'no-cache')
        self.assert_header_in_response(resp, 'Content-Type',
                                       'text/event-stream; charset=utf-8')

    def test_stream_custom_headers(self):
        resp = self.stream.stream(subchannel='a',
                                  headers={
                                      'x-custom': 'yes',
                                      'Cache-Control': 'no-store'
                                  })
        self.assert_header_in_response(resp, 'Connection', 'keep-alive')
        self.assert_header_in_response(resp, 'Cache-Control', 'no-store')
        self.assert_header_in_response(resp, 'Content-Type',
                                       'text/event-stream; charset=utf-8')
        self.assert_header_in_response(resp, 'x-custom', 'yes')

    def test_send(self):
        """Each subchannel receives only its own events, formatted with numbers starting at 1."""
        @self.stream.push('event1')
        def pusher(a, ev, sub):
            gevent.sleep(0.1)
            return {'a': a}, sub, ev

        subs = ('aaa', 'bbb')

        result = {sub: [] for sub in subs}

        def listen(sub):
            for event in self.stream.send(subchannel=sub):
                result[sub].append(event)

        base_args = [('event1', 1), ('event2', 2)]
        # Offset the data per subchannel so cross-delivery would be detected.
        args = {
            sub: [(event, data + i) for (event, data) in base_args]
            for i, sub in enumerate(subs)
        }

        def publish(sub):
            for event, data in args[sub]:
                pusher(data, event, sub)
            # Unsubscribing ends the listener's send() loop.
            self.stream.unsubscribe(sub)

        sses = {
            sub: [SseEvent(event, {'a': arg}) for event, arg in args[sub]]
            for sub in subs
        }
        formatted_sses = {
            sub: [sse.format(i + 1) for i, sse in enumerate(sse_vals)]
            for sub, sse_vals in sses.items()
        }

        listen_threads = [gevent.spawn(listen, sub) for sub in subs]
        publish_threads = [gevent.spawn(publish, sub) for sub in subs]
        gevent.sleep(0.1)
        self._join_and_reap([listen_threads, publish_threads])
        for sub in subs:
            self.assertListEqual(result[sub], formatted_sses[sub])

    def test_send_publish_multiple(self):
        """A single push addressed to several subchannels reaches every one of them."""

        subs = ('a', 'bbb')

        @self.stream.push('event1')
        def pusher(a, ev):
            gevent.sleep(0.1)
            return {'a': a}, subs, ev

        result = {sub: [] for sub in subs}

        def listen(sub):
            for event in self.stream.send(subchannel=sub):
                result[sub].append(event)

        base_args = [('event1', 1), ('event2', 2)]

        def publish():
            for event, data in base_args:
                pusher(data, event)
            for sub in subs:
                self.stream.unsubscribe(sub)

        sses = {
            sub: [SseEvent(event, {'a': arg}) for event, arg in base_args]
            for sub in subs
        }
        formatted_sses = {
            sub: [sse.format(i + 1) for i, sse in enumerate(sse_vals)]
            for sub, sse_vals in sses.items()
        }

        listen_threads = [gevent.spawn(listen, sub) for sub in subs]
        publish_thread = gevent.spawn(publish)
        gevent.sleep(0.1)
        self._join_and_reap([listen_threads, [publish_thread]])
        for sub in subs:
            self.assertListEqual(result[sub], formatted_sses[sub])

    def test_send_with_retry(self):
        """The ``retry`` value passed to send() is included in every formatted event."""
        @self.stream.push('event1')
        def pusher(a, ev, sub):
            gevent.sleep(0.1)
            return {'a': a}, sub, ev

        subs = ('a', 'b')

        result = {sub: [] for sub in subs}

        def listen(sub):
            for event in self.stream.send(subchannel=sub, retry=50):
                result[sub].append(event)

        base_args = [('event1', 1), ('event2', 2)]
        args = {
            sub: [(event, data + i) for (event, data) in base_args]
            for i, sub in enumerate(subs)
        }

        def publish(sub):
            for event, data in args[sub]:
                pusher(data, event, sub)
            self.stream.unsubscribe(sub)

        sses = {
            sub: [SseEvent(event, {'a': arg}) for event, arg in args[sub]]
            for sub in subs
        }
        formatted_sses = {
            sub:
            [sse.format(i + 1, retry=50) for i, sse in enumerate(sse_vals)]
            for sub, sse_vals in sses.items()
        }

        listen_threads = [gevent.spawn(listen, sub) for sub in subs]
        publish_threads = [gevent.spawn(publish, sub) for sub in subs]
        gevent.sleep(0.1)
        self._join_and_reap([listen_threads, publish_threads])
        for sub in subs:
            self.assertListEqual(result[sub], formatted_sses[sub])
 def setUp(self):
     """Give each test a handle on the mock Redis cache adapter."""
     adapter = MockRedisCacheAdapter()
     self.cache = adapter
예제 #21
0
 def setUp(self):
     """Point both result streams at a single MockRedisCacheAdapter."""
     self.cache = MockRedisCacheAdapter()
     workflow_stream.cache = action_stream.cache = self.cache
예제 #22
0
class TestSimpleFilteredSseStream(TestCase):
    @classmethod
    def setUpClass(cls):
        """One-time class setup: run ``patch_all()`` before any test executes.

        ``patch_all`` is presumably gevent's monkey-patching, which the
        greenlet-based send tests in this class depend on — confirm import.
        The parent hook is invoked first so cooperative class-level setup in
        base classes/mixins is not skipped.
        """
        # Explicit two-argument super() keeps this valid on Python 2 and 3.
        super(TestSimpleFilteredSseStream, cls).setUpClass()
        patch_all()

    def setUp(self):
        """Build a FilteredSseStream over a fresh mock cache on 'channel1'."""
        mock_cache = MockRedisCacheAdapter()
        channel_name = 'channel1'
        self.cache = mock_cache
        self.channel = channel_name
        self.stream = FilteredSseStream(channel_name, mock_cache)

    def tearDown(self):
        """Call ``clear()`` on the per-test mock cache so state does not carry over."""
        cache = self.cache
        cache.clear()

    def test_init(self):
        """The stream exposes the channel and cache it was constructed with."""
        stream_under_test = self.stream
        self.assertEqual(stream_under_test.channel, self.channel)
        self.assertEqual(stream_under_test.cache, self.cache)

    def test_create_channel_name(self):
        """Subchannel names are '<channel>.<subchannel>' for both str and int ids."""
        for subchannel, suffix in (('a', 'a'), (14, '14')):
            expected_name = '{}.{}'.format(self.channel, suffix)
            self.assertEqual(self.stream.create_subchannel_name(subchannel), expected_name)

    def assert_header_in_response(self, response, header, value):
        header_tuple = next((header_ for header_ in response.headers if header_[0] == header), None)
        self.assertIsNotNone(header_tuple)
        self.assertEqual(header_tuple[1], value)

    def test_stream_default_headers(self):
        """Without explicit headers, stream() sends the standard SSE response headers."""
        expected_headers = (
            ('Connection', 'keep-alive'),
            ('Cache-Control', 'no-cache'),
            ('Content-Type', 'text/event-stream; charset=utf-8'),
        )
        response = self.stream.stream(subchannel='a')
        for name, expected in expected_headers:
            self.assert_header_in_response(response, name, expected)

    def test_stream_custom_headers(self):
        """Caller-supplied headers are added and override same-named defaults."""
        custom_headers = {'x-custom': 'yes', 'Cache-Control': 'no-store'}
        response = self.stream.stream(subchannel='a', headers=custom_headers)
        for name, expected in (
            ('Connection', 'keep-alive'),
            ('Cache-Control', 'no-store'),
            ('Content-Type', 'text/event-stream; charset=utf-8'),
            ('x-custom', 'yes'),
        ):
            self.assert_header_in_response(response, name, expected)

    def test_send(self):
        """Publishers on separate subchannels each reach only their own listener."""

        @self.stream.push('event1')
        def pusher(a, ev, sub):
            gevent.sleep(0.1)
            return {'a': a}, sub, ev

        subchannels = ('aaa', 'bbb')
        received = {name: [] for name in subchannels}

        def listen(name):
            for message in self.stream.send(subchannel=name):
                received[name].append(message)

        base_args = [('event1', 1), ('event2', 2)]
        # Offset each subchannel's payloads by its index so they differ per subchannel.
        per_sub_args = {}
        for offset, name in enumerate(subchannels):
            per_sub_args[name] = [(event, data + offset) for event, data in base_args]

        def publish(name):
            for event, data in per_sub_args[name]:
                pusher(data, event, name)
            self.stream.unsubscribe(name)

        expected_events = {}
        for name in subchannels:
            expected_events[name] = [SseEvent(event, {'a': arg}) for event, arg in per_sub_args[name]]
        # Event ids are expected to start at 1 for each subchannel.
        expected = {}
        for name, events in expected_events.items():
            expected[name] = [sse.format(index) for index, sse in enumerate(events, 1)]

        listeners = [gevent.spawn(listen, name) for name in subchannels]
        publishers = [gevent.spawn(publish, name) for name in subchannels]
        gevent.sleep(0.1)
        gevent.joinall(listeners, timeout=2)
        gevent.joinall(publishers, timeout=2)
        for name in subchannels:
            self.assertListEqual(received[name], expected[name])

    def test_send_publish_multiple(self):
        """A push addressed to several subchannels is delivered to every one of them."""

        subchannels = ('a', 'bbb')

        @self.stream.push('event1')
        def pusher(a, ev):
            gevent.sleep(0.1)
            # The whole tuple of subchannels is returned as the push target.
            return {'a': a}, subchannels, ev

        received = {name: [] for name in subchannels}

        def listen(name):
            for message in self.stream.send(subchannel=name):
                received[name].append(message)

        base_args = [('event1', 1), ('event2', 2)]

        def publish():
            for event, data in base_args:
                pusher(data, event)
            for name in subchannels:
                self.stream.unsubscribe(name)

        expected_events = {
            name: [SseEvent(event, {'a': arg}) for event, arg in base_args]
            for name in subchannels
        }
        expected = {
            name: [sse.format(index) for index, sse in enumerate(events, 1)]
            for name, events in expected_events.items()
        }

        listeners = [gevent.spawn(listen, name) for name in subchannels]
        publisher = gevent.spawn(publish)
        gevent.sleep(0.1)
        gevent.joinall(listeners, timeout=2)
        publisher.join(timeout=2)
        for name in subchannels:
            self.assertListEqual(received[name], expected[name])

    def test_send_with_retry(self):
        """Each subchannel listener receives exactly its own events, formatted with retry=50.

        One publisher greenlet per subchannel pushes two events and then
        unsubscribes; the listener for that subchannel must see those events,
        numbered from 1, each carrying the ``retry`` value passed to ``send``.
        """

        @self.stream.push('event1')
        def pusher(a, ev, sub):
            gevent.sleep(0.1)
            return {'a': a}, sub, ev

        subs = ('a', 'b')

        # Derive the result buckets from ``subs`` so the two cannot drift apart.
        result = {sub: [] for sub in subs}

        def listen(sub):
            for event in self.stream.send(subchannel=sub, retry=50):
                result[sub].append(event)

        base_args = [('event1', 1), ('event2', 2)]
        # Offset each subchannel's payloads by its index so they differ per subchannel.
        args = {sub: [(event, data + i) for (event, data) in base_args] for i, sub in enumerate(subs)}

        def publish(sub):
            for event, data in args[sub]:
                pusher(data, event, sub)
            self.stream.unsubscribe(sub)

        sses = {sub: [SseEvent(event, {'a': arg}) for event, arg in args[sub]] for sub in subs}
        formatted_sses = {sub: [sse.format(i + 1, retry=50) for i, sse in enumerate(sse_vals)] for sub, sse_vals in
                          sses.items()}

        listen_threads = [gevent.spawn(listen, sub) for sub in subs]
        publish_threads = [gevent.spawn(publish, sub) for sub in subs]
        try:
            gevent.sleep(0.1)
            gevent.joinall(listen_threads, timeout=2)
            gevent.joinall(publish_threads, timeout=2)
        finally:
            # joinall() simply returns on timeout and leaves the greenlets
            # running; kill any stragglers so they do not outlive this test.
            gevent.killall(listen_threads + publish_threads)
        for sub in subs:
            self.assertListEqual(result[sub], formatted_sses[sub])