Пример #1
0
 def test_deserialize(self):
     """Round-trip a pipeline through serialize() and deserialize().

     The rebuilt pipeline must compare equal to the original one and must
     serialize to exactly the same value.
     """
     spam, eggs, ham, python, answer = (Job('spam'), Job('eggs'),
                                        Job('ham'), Job('python'),
                                        Job('answer_42'))
     graph = {spam: eggs, eggs: (ham, python), answer: None}
     original = Pipeline(graph, data={'key': 42})
     dumped = original.serialize()
     restored = Pipeline.deserialize(dumped)
     self.assertEqual(original, restored)
     self.assertEqual(dumped, restored.serialize())
Пример #2
0
    def test_str_and_save_dot(self):
        """Check the DOT text produced by ``str(pipeline)`` and ``save_dot``.

        Two graphs are rendered with ``str`` and compared against literal
        DOT sources. The second one is also written with ``save_dot`` to a
        temporary file, whose contents must be the same text followed by a
        newline.
        """
        pipeline = Pipeline({Job('A'): Job('B'), Job('C'): None})
        result = str(pipeline)
        expected = dedent('''
        digraph graphname {
            "A";
            "C";
            "B";

            "A" -> "B";
            "C" -> "(None)";
        }
        ''').strip()
        self.assertEqual(result, expected)

        pipeline = Pipeline({(Job('A'), Job('B'), Job('C')): [Job('D')],
                             Job('E'): (Job('B'), Job('F'))})
        result = str(pipeline)
        expected = dedent('''
        digraph graphname {
            "A";
            "C";
            "B";
            "E";
            "D";
            "F";

            "A" -> "D";
            "B" -> "D";
            "C" -> "D";
            "E" -> "B";
            "E" -> "F";
        }
        ''').strip()

        self.assertEqual(result, expected)
        # Only the file *name* is needed: create the file, close it right
        # away and let save_dot write to that path.
        temp_file = NamedTemporaryFile(delete=False)
        temp_file.close()
        try:
            pipeline.save_dot(temp_file.name)
            with open(temp_file.name) as dot_file:
                file_contents = dot_file.read()
        finally:
            # Remove the file even when save_dot or the read fails, so a
            # broken run does not leave temporary files behind.
            unlink(temp_file.name)
        self.assertEqual(expected + '\n', file_contents)
Пример #3
0
    def test_serialize(self):
        """Check the structure returned by ``Pipeline.serialize``.

        The serialized value is turned into a dict and must contain the
        ``'graph'`` edges (pairs of serialized jobs) and ``'data'``. A
        second pipeline carrying data is then round-tripped through
        ``deserialize``.
        """
        job_1, job_2, job_3, job_4 = (Job('spam'), Job('eggs'), Job('ham'),
                                      Job('python'))
        pipeline = Pipeline({job_1: job_2, job_2: (job_3, job_4)})
        result = dict(pipeline.serialize())
        expected = {'graph': ((job_1.serialize(), job_2.serialize()),
                              (job_2.serialize(), job_3.serialize()),
                              (job_2.serialize(), job_4.serialize())),
                    'data': None}

        # Compare the edges as sets so that their order does not matter.
        # Converting them with dict() would keep only one of the two edges
        # leaving job_2 (the last one wins), hiding a missing or wrong edge.
        result['graph'] = set(tuple(edge) for edge in result['graph'])
        expected['graph'] = set(expected['graph'])
        self.assertEqual(result, expected)

        pipeline = Pipeline({job_1: job_2}, data={'python': 42})
        self.assertEqual(pipeline, Pipeline.deserialize(pipeline.serialize()))
Пример #4
0
    def test_should_send_add_pipeline_with_serialized_pipeline(self):
        """The API request must be 'add pipeline' plus the serialized pipeline.

        ``send_pipeline`` is started through ``run_in_parallel`` (both are
        defined outside this block); the request it emits is read from
        ``self.api``, checked, and then answered so the worker can finish.
        """
        result, pool = run_in_parallel(send_pipeline)
        try:
            message = self.api.recv_json()
            received = Pipeline.deserialize(message['pipeline']).serialize()
            expected = self.pipeline.serialize()
            self.assertEqual(set(message.keys()),
                             set(['command', 'pipeline']))
            self.assertEqual(message['command'], 'add pipeline')
            self.assertEqual(received, expected)

            pipeline_id = uuid4().hex
            self.api.send_json({'answer': 'pipeline accepted',
                                'pipeline id': pipeline_id})
            result.get()
        finally:
            # Stop the pool even when an assertion above fails; otherwise
            # the worker would be left waiting for a reply that never comes.
            pool.terminate()
Пример #5
0
 def setUp(self):
     """Create the ZeroMQ context, start the sockets and build a pipeline.

     The pipeline chains worker_1 -> worker_2 -> worker_3.
     """
     self.context = zmq.Context()
     self.start_router_sockets()
     graph = {Job(u'worker_1'): Job(u'worker_2'),
              Job(u'worker_2'): Job(u'worker_3')}
     self.pipeline = Pipeline(graph)
Пример #6
0
class PipelineManagerTest(unittest.TestCase):
    """Tests for PipelineManager, with this test case standing in for the router.

    ``setUp`` binds a REP socket on ``API_ADDRESS`` and a PUB socket on
    ``BROADCAST_ADDRESS``. Several tests start a helper function through
    ``run_in_parallel`` (defined outside this class) and then answer its
    requests from these sockets.
    """

    def setUp(self):
        """Create the ZeroMQ context, bind the sockets and build a pipeline."""
        self.context = zmq.Context()
        self.start_router_sockets()
        self.pipeline = Pipeline({Job(u'worker_1'): Job(u'worker_2'),
                                  Job(u'worker_2'): Job(u'worker_3')})

    def tearDown(self):
        """Close both sockets and terminate the ZeroMQ context."""
        self.close_sockets()
        self.context.term()

    def start_router_sockets(self):
        """Bind the API (REP) and broadcast (PUB) sockets."""
        self.api = self.context.socket(zmq.REP)
        self.broadcast = self.context.socket(zmq.PUB)
        self.api.bind(API_ADDRESS)
        self.broadcast.bind(BROADCAST_ADDRESS)

    def close_sockets(self):
        """Close the API and broadcast sockets."""
        self.api.close()
        self.broadcast.close()

    def _reply_pipeline_accepted(self):
        """Answer the pending API request and return the new pipeline id."""
        pipeline_id = uuid4().hex
        self.api.send_json({'answer': 'pipeline accepted',
                            'pipeline id': pipeline_id})
        return pipeline_id

    def test_repr(self):
        """repr() must report how many pipelines were submitted and finished."""
        pipeline_manager = PipelineManager(api=API_ADDRESS,
                                           broadcast=BROADCAST_ADDRESS)
        pipeline_ids = [uuid4().hex for _ in range(10)]
        pipeline_ids_copy = pipeline_ids[:]
        # Replace the API round-trip with stubs so that no socket traffic is
        # needed: every reply hands out one of the pre-generated ids.
        pipeline_manager.send_api_request = lambda x: None
        pipeline_manager.get_api_reply = \
                lambda: {'pipeline id': pipeline_ids.pop()}
        pipelines = [Pipeline({Job('A', data={'index': i}): Job('B')})
                     for i in range(10)]
        for pipeline in pipelines:
            pipeline_manager.start(pipeline)

        result = repr(pipeline_manager)
        self.assertEqual(result, '<PipelineManager: 10 submitted, 0 finished>')

        messages = ['pipeline finished: id={}, duration=0.1'.format(pipeline_id)
                    for pipeline_id in pipeline_ids_copy[:3]]
        # Values are popped from the end: three positive polls (one per
        # message), then a negative one.
        poll = [False, True, True, True]

        def new_poll(timeout):
            return poll.pop()

        def new_broadcast_receive():
            return messages.pop()

        pipeline_manager.broadcast_poll = new_poll
        pipeline_manager.broadcast_receive = new_broadcast_receive
        pipeline_manager.update(0.1)
        result = repr(pipeline_manager)
        self.assertEqual(result, '<PipelineManager: 10 submitted, 3 finished>')

    def test_should_send_add_pipeline_with_serialized_pipeline(self):
        """The API request must be 'add pipeline' plus the serialized pipeline."""
        result, pool = run_in_parallel(send_pipeline)
        try:
            message = self.api.recv_json()
            received = Pipeline.deserialize(message['pipeline']).serialize()
            expected = self.pipeline.serialize()
            self.assertEqual(set(message.keys()),
                             set(['command', 'pipeline']))
            self.assertEqual(message['command'], 'add pipeline')
            self.assertEqual(received, expected)

            self._reply_pipeline_accepted()
            result.get()
        finally:
            # Stop the pool even when an assertion above fails; otherwise
            # the worker would be left waiting for a reply that never comes.
            pool.terminate()

    def test_should_save_pipeline_id_on_pipeline_object(self):
        """The id sent in the reply must show up in the worker's result."""
        result, pool = run_in_parallel(send_pipeline)
        try:
            self.api.recv_json()
            pipeline_id = self._reply_pipeline_accepted()
            received = result.get()
        finally:
            pool.terminate()
        self.assertEqual(received, (None, pipeline_id, pipeline_id))

    def test_should_subscribe_to_broadcast_to_wait_for_finished_pipeline(self):
        """Accept ten pipelines, broadcast their completion, check the result."""
        result, pool = run_in_parallel(send_pipeline_and_wait_finished)
        try:
            pipeline_ids = []
            for _ in range(10):
                self.api.recv_json()
                pipeline_ids.append(self._reply_pipeline_accepted())
            sleep(1)
            for pipeline_id in pipeline_ids:
                self.broadcast.send(
                        'pipeline finished: id={}, duration=1.23456'
                        .format(pipeline_id))
            received = result.get()
        finally:
            pool.terminate()
        self.assertEqual(received['duration'], 1.23456)
        self.assertTrue(received['real_duration'] > 1)
        # assertEqual, not assertTrue: with assertTrue the second argument
        # is only the failure message, so the counts were never compared.
        self.assertEqual(received['finished_pipelines'], 10)
        self.assertEqual(received['started_pipelines'], 10)

    def test_should_raise_ValueError_in_some_cases(self):
        """Check the flags and start time reported by the helper function."""
        result, pool = run_in_parallel(verify_PipelineManager_exceptions)
        try:
            self.api.recv_json()
            self._reply_pipeline_accepted()
            start_time = time()
            received = result.get()
        finally:
            pool.terminate()
        self.assertTrue(received['raise_1'])
        self.assertTrue(received['raise_2'])
        started_at = received['started_at']
        # The reported start time must be within 0.1s of the reply time.
        self.assertTrue(start_time - 0.1 <= started_at <= start_time + 0.1)

    def test_should_return_all_pipelines(self):
        """``pipelines`` must contain every pipeline passed to ``start``."""
        pipeline_manager = PipelineManager(api=API_ADDRESS,
                                           broadcast=BROADCAST_ADDRESS)
        # Stub out the API round-trip; each start() gets a fresh id.
        pipeline_manager.send_api_request = lambda x: None
        pipeline_manager.get_api_reply = lambda: {'pipeline id': uuid4().hex}
        iterations = 10
        pipelines = []
        for i in range(iterations):
            pipeline = Pipeline({Job(u'worker_1'): Job(u'worker_2'),
                                 Job(u'worker_2'): Job(u'worker_3')},
                                data={'index': i})
            pipeline_manager.start(pipeline)
            pipelines.append(pipeline)
        self.assertEqual(set(pipeline_manager.pipelines), set(pipelines))
Пример #7
0
 def send_pipeline(self, pipeline_definition):
     """Build a Pipeline from ``pipeline_definition`` and send it as JSON.

     ``pipeline_definition`` is indexed with the keys ``'graph'``,
     ``'data'`` and ``'pipeline id'``. The message written to ``self.api``
     carries the serialized pipeline together with that id.
     """
     graph = pipeline_definition['graph']
     data = pipeline_definition['data']
     pipeline = Pipeline(graph, data=data)
     message = {'pipeline': pipeline.serialize(),
                'pipeline id': pipeline_definition['pipeline id']}
     self.api.send_json(message)
Пример #8
0
class PipelineManagerTest(unittest.TestCase):
    """Tests for PipelineManager, with this test case standing in for the router.

    ``setUp`` binds a REP socket (API) and a PUB socket (broadcast) on fixed
    local TCP ports. Each test starts a helper function through
    ``run_in_parallel`` (defined outside this class) and answers its
    requests from these sockets.
    """

    def setUp(self):
        """Create the ZeroMQ context, bind the sockets and build a pipeline."""
        self.context = zmq.Context()
        self.start_router_sockets()
        self.pipeline = Pipeline({Job(u'worker_1'): Job(u'worker_2'),
                                  Job(u'worker_2'): Job(u'worker_3')})

    def tearDown(self):
        """Close both sockets and terminate the ZeroMQ context."""
        self.close_sockets()
        self.context.term()

    def start_router_sockets(self):
        """Bind the API (REP) and broadcast (PUB) sockets."""
        self.api = self.context.socket(zmq.REP)
        self.broadcast = self.context.socket(zmq.PUB)
        self.api.bind('tcp://127.0.0.1:5550')
        self.broadcast.bind('tcp://127.0.0.1:5551')

    def close_sockets(self):
        """Close the API and broadcast sockets."""
        self.api.close()
        self.broadcast.close()

    def _reply_pipeline_accepted(self):
        """Answer the pending API request and return the new pipeline id."""
        pipeline_id = uuid4().hex
        self.api.send_json({'answer': 'pipeline accepted',
                            'pipeline id': pipeline_id})
        return pipeline_id

    def test_should_send_add_pipeline_with_serialized_pipeline(self):
        """The API request must be 'add pipeline' plus the serialized pipeline."""
        result, pool = run_in_parallel(send_pipeline)
        try:
            message = self.api.recv_json()
            received = Pipeline.deserialize(message['pipeline']).serialize()
            expected = self.pipeline.serialize()
            self.assertEqual(set(message.keys()),
                             set(['command', 'pipeline']))
            self.assertEqual(message['command'], 'add pipeline')
            self.assertEqual(received, expected)

            self._reply_pipeline_accepted()
            result.get()
        finally:
            # Stop the pool even when an assertion above fails; otherwise
            # the worker would be left waiting for a reply that never comes.
            pool.terminate()

    def test_should_save_pipeline_id_on_pipeline_object(self):
        """The id sent in the reply must show up in the worker's result."""
        result, pool = run_in_parallel(send_pipeline)
        try:
            self.api.recv_json()
            pipeline_id = self._reply_pipeline_accepted()
            received = result.get()
        finally:
            pool.terminate()
        self.assertEqual(received, (None, pipeline_id, pipeline_id))

    def test_should_subscribe_to_broadcast_to_wait_for_finished_pipeline(self):
        """Accept one pipeline, broadcast its completion, check the result."""
        result, pool = run_in_parallel(send_pipeline_and_wait_finished)
        try:
            self.api.recv_json()
            pipeline_id = self._reply_pipeline_accepted()
            sleep(1)
            self.broadcast.send('pipeline finished: id={}, duration=1.23456'
                                .format(pipeline_id))
            received = result.get()
        finally:
            pool.terminate()
        self.assertEqual(received['duration'], 1.23456)
        self.assertTrue(received['real_duration'] > 1)

    def test_should_raise_ValueError_in_some_cases(self):
        """Check the flags and start time reported by the helper function."""
        result, pool = run_in_parallel(verify_PipelineManager_exceptions)
        try:
            self.api.recv_json()
            self._reply_pipeline_accepted()
            start_time = time()
            received = result.get()
        finally:
            pool.terminate()
        self.assertTrue(received['raise_1'])
        self.assertTrue(received['raise_2'])
        started_at = received['started_at']
        # The reported start time must be within 0.1s of the reply time.
        self.assertTrue(start_time - 0.1 <= started_at <= start_time + 0.1)