Ejemplo n.º 1
0
 def _create_writer_with_fail(self, session_id, chunk_key, *args, **kwargs):
     """Stand-in for ``create_writer`` that fails for ``fail_key`` only.

     Calls for any other chunk key are forwarded unchanged to
     ``old_create_writer``. For ``fail_key`` a ``ValueError`` is reported:
     as a rejected promise (``_accept=False``) when the ``_promise`` keyword
     is absent or truthy, otherwise by raising it directly.
     """
     if chunk_key != fail_key:
         return old_create_writer(self, session_id, chunk_key, *args, **kwargs)
     if not kwargs.get('_promise', True):
         raise ValueError
     return promise.finished(*build_exc_info(ValueError), _accept=False)
Ejemplo n.º 2
0
    def run_test_cache(self):
        """Fill the chunk store past capacity and check spill / reload behaviour.

        Nine random int16 arrays are put into the chunk store, each through
        its own promise. Whenever ``StoreFull`` is raised, the chunk holder is
        asked to spill twice the chunk size and the put is retried; a
        ``NoDataToSpill`` rejection is tolerated.

        Once every put has settled, the chain checks the following in order:
        the last chunk is present; ``ensure_chunk`` keeps it present; the
        first chunk is absent until ``ensure_chunk`` restores it. Finally,
        unregistering plus ``evict(128)`` removes each of them.

        A failure anywhere is stored on ``self._exc_info``, and
        ``self._finished`` is set to True at the end.
        """
        session_id = str(uuid.uuid4())

        holder_ref = self.promise_ref(ChunkHolderActor.default_name())
        store = self._chunk_store

        # (key, array) pairs; the uuid is drawn before the array each time.
        items = [
            (str(uuid.uuid4()), np.random.randint(0, 32767, (655360, ), np.int16))
            for _ in range(9)
        ]

        def _swallow_no_spill(*exc_info):
            # Having nothing left to spill is acceptable; anything else propagates.
            if not issubclass(exc_info[0], NoDataToSpill):
                six.reraise(*exc_info)

        def _put_chunk(key, value, *_):
            try:
                obj_ref = store.put(session_id, key, value)
                holder_ref.register_chunk(session_id, key)
                self.ctx.sleep(0.5)
                del obj_ref
            except StoreFull:
                # Ask the holder to free twice the chunk size, then retry the put.
                spill_promise = holder_ref.spill_size(calc_data_size(value) * 2, _promise=True)
                return spill_promise.then(partial(_put_chunk, key, value), _swallow_no_spill)

        pending = [promise.finished().then(partial(_put_chunk, key, value))
                   for key, value in items]

        def _expect(key, present):
            # Step asserting that ``key`` is (or is not) held in the store.
            def _check(*_):
                found = store.contains(session_id, key)
                assert found if present else not found
            return _check

        def _ensure(key):
            return lambda *_: ensure_chunk(self, session_id, key)

        def _release(key):
            return lambda *_: holder_ref.unregister_chunk(session_id, key)

        def _evict(*_):
            return self._plasma_client.evict(128)

        first_id, last_id = items[0][0], items[-1][0]

        steps = (
            _expect(last_id, True), _ensure(last_id), _expect(last_id, True),
            _expect(first_id, False), _ensure(first_id), _expect(first_id, True),
            _release(first_id), _evict, _expect(first_id, False),
            _release(last_id), _evict, _expect(last_id, False),
        )
        chain = promise.all_(pending)
        for step in steps:
            chain = chain.then(step)

        chain.catch(lambda *exc: setattr(self, '_exc_info', exc)) \
            .then(lambda *_: setattr(self, '_finished', True))
Ejemplo n.º 3
0
 def _mock_load_from(*_, **__):
     """Replacement loader that always yields a rejected promise.

     All positional and keyword arguments are ignored. The returned promise
     is finished with the exception info built for ``SystemError`` and
     ``_accept=False``.
     """
     exc_info = build_exc_info(SystemError)
     return promise.finished(*exc_info, _accept=False)
Ejemplo n.º 4
0
    def testDispatch(self, *_):
        """Check slot bookkeeping and dispatch timing of ``DispatchActor``.

        ``group_size`` ``TaskActor`` instances are created for each of the
        groups 'g1' and 'g2'; group 'g3' has none. The test then verifies:

        * slot counts per group;
        * that ``get_hash_slot`` is stable for the same key;
        * that acquiring a 'g1' slot with a callback to a non-existent actor
          leaves all 'g1' slots free.

        Finally, ``group_size + 1`` tasks are dispatched to each group. The
        first ``group_size`` should start almost together, while the last
        one should start between 0.5 and 1.5 times ``delay`` after the first.
        Acquiring a slot in the empty group 'g3' calls back with ``None``.
        """
        call_records = dict()
        group_size = 4

        mock_scheduler_addr = f'127.0.0.1:{get_next_port()}'
        with create_actor_pool(n_process=1,
                               backend='gevent',
                               address=mock_scheduler_addr) as pool:
            dispatch_ref = pool.create_actor(DispatchActor,
                                             uid=DispatchActor.default_uid())
            # actors of g1, then of g2 (plain loops: only the side effect of
            # registering the actors is needed, not their refs)
            for group in ('g1', 'g2'):
                for _ in range(group_size):
                    pool.create_actor(TaskActor, group, call_records)

            self.assertEqual(len(dispatch_ref.get_slots('g1')), group_size)
            self.assertEqual(len(dispatch_ref.get_slots('g2')), group_size)
            self.assertEqual(len(dispatch_ref.get_slots('g3')), 0)

            self.assertEqual(dispatch_ref.get_hash_slot('g1', 'hash_str'),
                             dispatch_ref.get_hash_slot('g1', 'hash_str'))

            # a callback to a non-existent actor must not keep a slot occupied
            dispatch_ref.acquire_free_slot('g1',
                                           callback=(('NonExist',
                                                      mock_scheduler_addr),
                                                     '_non_exist', {}))
            self.assertEqual(dispatch_ref.get_free_slots_num().get('g1'),
                             group_size)

            # tasks within [0, group_size - 1] will run almost simultaneously,
            # while the last one will be delayed due to lack of slots

            delay = 1

            with self.run_actor_test(pool) as test_actor:
                p = promise.finished()
                _dispatch_ref = test_actor.promise_ref(
                    DispatchActor.default_uid())

                def _call_on_dispatched(uid, key=None):
                    # A None uid means no slot exists for the group; otherwise
                    # ask the dispatched actor to run a queued call for `key`.
                    if uid is None:
                        call_records[key] = 'NoneUID'
                    else:
                        test_actor.promise_ref(uid).queued_call(key,
                                                                delay,
                                                                _tell=True,
                                                                _wait=False)

                # `idx` is baked into each key by the f-string when the chain
                # is built, so there is no late-binding issue here.
                for idx in range(group_size + 1):
                    p = p.then(lambda *_: _dispatch_ref.acquire_free_slot('g1', _promise=True)) \
                        .then(partial(_call_on_dispatched, key=f'{idx}_1')) \
                        .then(lambda *_: _dispatch_ref.acquire_free_slot('g2', _promise=True)) \
                        .then(partial(_call_on_dispatched, key=f'{idx}_2'))

                p.then(lambda *_: _dispatch_ref.acquire_free_slot('g3', _promise=True)) \
                    .then(partial(_call_on_dispatched, key='N_1')) \
                    .then(lambda *_: test_actor.set_result(None))

            self.get_result(20)

            # NOTE(review): call_records values are presumably timestamps
            # written by TaskActor.queued_call; they are compared as times.
            self.assertEqual(call_records['N_1'], 'NoneUID')
            self.assertLess(
                sum(
                    abs(call_records[f'{idx}_1'] - call_records['0_1'])
                    for idx in range(group_size)), delay * 0.5)
            self.assertGreater(
                call_records[f'{group_size}_1'] - call_records['0_1'],
                delay * 0.5)
            self.assertLess(
                call_records[f'{group_size}_1'] - call_records['0_1'],
                delay * 1.5)

            dispatch_ref.destroy()
Ejemplo n.º 5
0
    def testPromise(self):
        """Exercise ``promise.Promise`` chaining, rejection and recovery.

        A daemon worker thread stands in for an asynchronous backend. Each
        promise produced by ``gen_promise`` is settled by that thread after
        ``time_unit`` seconds with ``value + 1``, either accepted or rejected.
        Every scenario records the order of ``gen_promise``, ``thread_body``
        and ``catch`` events in ``value_list`` and compares it with the
        expected trace.

        On exit the test checks that ``promise._promise_pool`` is empty. It
        always sends the stop signal to the worker thread, even when that
        check fails.
        """
        promises = weakref.WeakValueDictionary()
        req_queue = Queue()
        value_list = []

        time_unit = 0.1

        def test_thread_body():
            # Worker loop: settle queued promises until a None value arrives
            # as the stop signal.
            while True:
                idx, v, success = req_queue.get()
                if v is None:
                    break
                value_list.append(('thread_body', v))
                time.sleep(time_unit)
                promises[idx].step_next([(v, ), dict(_accept=success)])

        try:
            thread = threading.Thread(target=test_thread_body)
            thread.daemon = True
            thread.start()

            def gen_promise(value, accept=True):
                # Create a pending promise and ask the worker to settle it
                # with value + 1, accepted or rejected according to `accept`.
                value_list.append(('gen_promise', value))
                p = promise.Promise()
                promises[p.id] = p
                req_queue.put((p.id, value + 1, accept))
                return p

            # simple promise call
            value_list = []
            p = gen_promise(0) \
                .then(lambda v: gen_promise(v)) \
                .then(lambda v: gen_promise(v))
            p.wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1),
                                              ('gen_promise', 1),
                                              ('thread_body', 2),
                                              ('gen_promise', 2),
                                              ('thread_body', 3)])

            # continue accepted call with then
            value_list = []
            p.then(lambda *_: gen_promise(0)) \
                .then(lambda v: gen_promise(v)) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1),
                                              ('gen_promise', 1),
                                              ('thread_body', 2)])

            # immediate error
            value_list = []
            p = promise.finished() \
                .then(lambda *_: 5 / 0)
            p.catch(lambda *_: gen_promise(0)) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1)])

            # chained errors
            value_list = []
            p = promise.finished(_accept=False) \
                .catch(lambda *_: 1 / 0) \
                .catch(lambda *_: 2 / 0) \
                .catch(lambda *_: gen_promise(0)) \
                .catch(lambda *_: gen_promise(1))
            p.wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1)])

            # continue error call
            value_list = []
            p = gen_promise(0) \
                .then(lambda *_: 5 / 0) \
                .then(lambda *_: gen_promise(2))
            time.sleep(0.5)
            value_list = []
            p.catch(lambda *_: gen_promise(0)) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1)])

            # nested promise returned from a `then` callback
            value_list = []
            gen_promise(0) \
                .then(lambda v: gen_promise(v).then(lambda x: x + 1)) \
                .then(lambda v: gen_promise(v)) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1),
                                              ('gen_promise', 1),
                                              ('thread_body', 2),
                                              ('gen_promise', 3),
                                              ('thread_body', 4)])

            # rejection at the end of a chain reaches `catch`
            value_list = []
            gen_promise(0) \
                .then(lambda v: gen_promise(v)) \
                .then(lambda v: gen_promise(v, False)) \
                .catch(lambda v: value_list.append(('catch', v))) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1),
                                              ('gen_promise', 1),
                                              ('thread_body', 2),
                                              ('gen_promise', 2),
                                              ('thread_body', 3),
                                              ('catch', 3)])

            # rejection inside a nested promise skips the following `then`s
            value_list = []
            gen_promise(0) \
                .then(lambda v: gen_promise(v, False).then(lambda x: x + 1)) \
                .then(lambda v: gen_promise(v)) \
                .catch(lambda v: value_list.append(('catch', v))) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1),
                                              ('gen_promise', 1),
                                              ('thread_body', 2),
                                              ('catch', 2)])

            # plain value returned from a `then` callback
            value_list = []
            gen_promise(0) \
                .then(lambda v: gen_promise(v)) \
                .then(lambda v: v + 1) \
                .then(lambda v: gen_promise(v, False)) \
                .catch(lambda v: value_list.append(('catch', v))) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1),
                                              ('gen_promise', 1),
                                              ('thread_body', 2),
                                              ('gen_promise', 3),
                                              ('thread_body', 4),
                                              ('catch', 4)])

            # `catch` handlers are skipped while the chain is accepted
            value_list = []
            gen_promise(0) \
                .then(lambda v: gen_promise(v)) \
                .catch(lambda v: gen_promise(v)) \
                .catch(lambda v: gen_promise(v)) \
                .then(lambda v: gen_promise(v, False)) \
                .catch(lambda v: value_list.append(('catch', v))) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1),
                                              ('gen_promise', 1),
                                              ('thread_body', 2),
                                              ('gen_promise', 2),
                                              ('thread_body', 3),
                                              ('catch', 3)])

            # two-argument `then`: the reject handler runs on rejection
            value_list = []
            gen_promise(0) \
                .then(lambda v: gen_promise(v)) \
                .catch(lambda v: gen_promise(v)) \
                .catch(lambda v: gen_promise(v)) \
                .then(lambda v: gen_promise(v, False)) \
                .then(lambda v: gen_promise(v), lambda v: gen_promise(v + 1, False)) \
                .catch(lambda v: value_list.append(('catch', v))) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1),
                                              ('gen_promise', 1),
                                              ('thread_body', 2),
                                              ('gen_promise', 2),
                                              ('thread_body', 3),
                                              ('gen_promise', 4),
                                              ('thread_body', 5),
                                              ('catch', 5)])

            # exception raised inside a reject handler reaches the next `catch`
            value_list = []
            gen_promise(0) \
                .then(lambda v: gen_promise(v)) \
                .catch(lambda v: gen_promise(v)) \
                .catch(lambda v: gen_promise(v)) \
                .then(lambda v: gen_promise(v, False)) \
                .then(lambda v: gen_promise(v), lambda v: _raise_exception(ValueError)) \
                .catch(lambda *_: value_list.append(('catch',))) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1),
                                              ('gen_promise', 1),
                                              ('thread_body', 2),
                                              ('gen_promise', 2),
                                              ('thread_body', 3), ('catch', )])

            # recovery through `catch` resumes the accepted path
            value_list = []
            gen_promise(0) \
                .then(lambda v: gen_promise(v, False)) \
                .catch(lambda v: gen_promise(v, False)) \
                .catch(lambda v: gen_promise(v)) \
                .then(lambda v: gen_promise(v)) \
                .catch(lambda v: value_list.append(('catch', v))) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1),
                                              ('gen_promise', 1),
                                              ('thread_body', 2),
                                              ('gen_promise', 2),
                                              ('thread_body', 3),
                                              ('gen_promise', 3),
                                              ('thread_body', 4),
                                              ('gen_promise', 4),
                                              ('thread_body', 5)])

            # rejection skips the next `then` and lands in `catch`
            value_list = []
            gen_promise(0) \
                .then(lambda v: gen_promise(v, False)) \
                .then(lambda v: gen_promise(v)) \
                .catch(lambda v: value_list.append(('catch', v))) \
                .wait()
            self.assertListEqual(value_list, [('gen_promise', 0),
                                              ('thread_body', 1),
                                              ('gen_promise', 1),
                                              ('thread_body', 2),
                                              ('catch', 2)])
        finally:
            try:
                self.assertDictEqual(promise._promise_pool, {})
            finally:
                # Stop the worker even if the pool check fails; otherwise the
                # thread would stay blocked on req_queue.get() forever.
                req_queue.put((None, None, None))
Ejemplo n.º 6
0
    def testReceiver(self):
        """Drive a ``ReceiverActor`` through its receive lifecycle.

        Covers:

        * ``check_status`` before and after data lands in storage;
        * ``create_data_writer`` in promise and synchronous
          (``use_promise=False``) form;
        * ``cancel_receive``, ``receive_data_part``, ``finish_receive``,
          ``register_finish_callback`` and ``notify_dead_senders``.

        Also checks that these failures surface through the finish callback
        or the call itself: ``ChecksumMismatch``, ``ExecutionInterrupted``,
        ``StorageFull``, ``SpillNotConfigured``, ``TimeoutError`` and
        ``WorkerDead``.
        """
        pool_addr = 'localhost:%d' % get_next_port()
        # NOTE(review): this mutates a global option and the temp directory is
        # not removed here — confirm that tearDown/cleanup handles both.
        options.worker.spill_directory = tempfile.mkdtemp(
            prefix='mars_test_receiver_')
        session_id = str(uuid.uuid4())

        # Serialized payload plus its size and CRC32, reused by every case.
        mock_data = np.array([1, 2, 3, 4])
        serialized_arrow_data = dataserializer.serialize(mock_data)
        data_size = serialized_arrow_data.total_bytes
        serialized_mock_data = serialized_arrow_data.to_buffer()
        serialized_crc32 = zlib.crc32(serialized_arrow_data.to_buffer())

        # One distinct chunk key per scenario below.
        chunk_key1 = str(uuid.uuid4())
        chunk_key2 = str(uuid.uuid4())
        chunk_key3 = str(uuid.uuid4())
        chunk_key4 = str(uuid.uuid4())
        chunk_key5 = str(uuid.uuid4())
        chunk_key6 = str(uuid.uuid4())
        chunk_key7 = str(uuid.uuid4())
        chunk_key8 = str(uuid.uuid4())

        with start_transfer_test_pool(
                address=pool_addr,
                plasma_size=self.plasma_storage_size) as pool:
            receiver_ref = pool.create_actor(ReceiverActor,
                                             uid=str(uuid.uuid4()))

            with self.run_actor_test(pool) as test_actor:
                storage_client = test_actor.storage_client

                # check_status on receiving and received
                self.assertEqual(
                    receiver_ref.check_status(session_id, chunk_key1),
                    ReceiveStatus.NOT_STARTED)

                # Write chunk_key1 to DISK directly; the receiver should then
                # report it as RECEIVED.
                self.waitp(
                    storage_client.create_writer(
                        session_id, chunk_key1,
                        serialized_arrow_data.total_bytes,
                        [DataStorageDevice.DISK
                         ]).then(lambda writer: promise.finished().then(
                             lambda *_: writer.write(serialized_arrow_data)).
                                 then(lambda *_: writer.close())))
                self.assertEqual(
                    receiver_ref.check_status(session_id, chunk_key1),
                    ReceiveStatus.RECEIVED)
                storage_client.delete(session_id, chunk_key1)

                # Same check with the object put into SHARED_MEMORY instead.
                self.waitp(
                    storage_client.put_object(
                        session_id, chunk_key1, mock_data,
                        [DataStorageDevice.SHARED_MEMORY]))

                self.assertEqual(
                    receiver_ref.check_status(session_id, chunk_key1),
                    ReceiveStatus.RECEIVED)

                receiver_ref_p = test_actor.promise_ref(receiver_ref)

                # cancel on an un-run / missing result will result in nothing
                receiver_ref_p.cancel_receive(session_id, chunk_key2)

                # start creating writer
                receiver_ref_p.create_data_writer(session_id, chunk_key1, data_size, test_actor, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.assertTupleEqual(
                    self.get_result(5),
                    (receiver_ref.address, ReceiveStatus.RECEIVED))

                result = receiver_ref_p.create_data_writer(session_id,
                                                           chunk_key1,
                                                           data_size,
                                                           test_actor,
                                                           use_promise=False)
                self.assertTupleEqual(
                    result, (receiver_ref.address, ReceiveStatus.RECEIVED))

                # A fresh key gets a writer (status None); asking again while
                # it is open reports RECEIVING in both call styles.
                receiver_ref_p.create_data_writer(session_id, chunk_key2, data_size, test_actor, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.assertTupleEqual(self.get_result(5),
                                      (receiver_ref.address, None))

                result = receiver_ref_p.create_data_writer(session_id,
                                                           chunk_key2,
                                                           data_size,
                                                           test_actor,
                                                           use_promise=False)
                self.assertTupleEqual(
                    result, (receiver_ref.address, ReceiveStatus.RECEIVING))

                receiver_ref_p.create_data_writer(session_id, chunk_key2, data_size, test_actor, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.assertTupleEqual(
                    self.get_result(5),
                    (receiver_ref.address, ReceiveStatus.RECEIVING))

                # Cancelling resets the key to NOT_STARTED.
                receiver_ref_p.cancel_receive(session_id, chunk_key2)
                self.assertEqual(
                    receiver_ref.check_status(session_id, chunk_key2),
                    ReceiveStatus.NOT_STARTED)

                # test checksum error on receive_data_part
                receiver_ref_p.create_data_writer(session_id, chunk_key2, data_size, test_actor, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.get_result(5)

                receiver_ref_p.register_finish_callback(session_id, chunk_key2, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s)) \
                    .catch(lambda *exc: test_actor.set_result(exc, accept=False))

                # checksum 0 does not match the payload
                receiver_ref_p.receive_data_part(session_id, chunk_key2,
                                                 serialized_mock_data, 0)

                with self.assertRaises(ChecksumMismatch):
                    self.get_result(5)

                # test checksum error on finish_receive
                receiver_ref_p.create_data_writer(session_id, chunk_key2, data_size, test_actor, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.assertTupleEqual(self.get_result(5),
                                      (receiver_ref.address, None))

                receiver_ref_p.receive_data_part(session_id, chunk_key2,
                                                 serialized_mock_data,
                                                 serialized_crc32)
                receiver_ref_p.finish_receive(session_id, chunk_key2, 0)

                receiver_ref_p.register_finish_callback(session_id, chunk_key2, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s)) \
                    .catch(lambda *exc: test_actor.set_result(exc, accept=False))

                with self.assertRaises(ChecksumMismatch):
                    self.get_result(5)

                receiver_ref_p.cancel_receive(session_id, chunk_key2)

                # test intermediate cancellation
                receiver_ref_p.create_data_writer(session_id, chunk_key2, data_size, test_actor, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.assertTupleEqual(self.get_result(5),
                                      (receiver_ref.address, None))

                receiver_ref_p.register_finish_callback(session_id, chunk_key2, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s)) \
                    .catch(lambda *exc: test_actor.set_result(exc, accept=False))

                # Send the first 64 bytes, cancel, then send the rest: the
                # finish callback is expected to see ExecutionInterrupted.
                receiver_ref_p.receive_data_part(
                    session_id, chunk_key2, serialized_mock_data[:64],
                    zlib.crc32(serialized_mock_data[:64]))
                receiver_ref_p.cancel_receive(session_id, chunk_key2)
                receiver_ref_p.receive_data_part(session_id, chunk_key2,
                                                 serialized_mock_data[64:],
                                                 serialized_crc32)
                with self.assertRaises(ExecutionInterrupted):
                    self.get_result(5)

                # test transfer in memory
                receiver_ref_p.register_finish_callback(session_id, chunk_key3, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s)) \
                    .catch(lambda *exc: test_actor.set_result(exc, accept=False))

                receiver_ref_p.create_data_writer(session_id, chunk_key3, data_size, test_actor, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.assertTupleEqual(self.get_result(5),
                                      (receiver_ref.address, None))

                # Two parts with matching checksums; a successful finish
                # callback yields an empty tuple.
                receiver_ref_p.receive_data_part(
                    session_id, chunk_key3, serialized_mock_data[:64],
                    zlib.crc32(serialized_mock_data[:64]))
                receiver_ref_p.receive_data_part(session_id, chunk_key3,
                                                 serialized_mock_data[64:],
                                                 serialized_crc32)
                receiver_ref_p.finish_receive(session_id, chunk_key3,
                                              serialized_crc32)

                self.assertTupleEqual((), self.get_result(5))

                receiver_ref_p.create_data_writer(session_id, chunk_key3, data_size, test_actor, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.assertTupleEqual(
                    self.get_result(5),
                    (receiver_ref.address, ReceiveStatus.RECEIVED))

                # test transfer in spill file
                # Shared-memory creation is patched to always report StorageFull.
                def mocked_store_create(*_):
                    raise StorageFull

                with patch_method(PlasmaSharedStore.create,
                                  new=mocked_store_create):
                    # ensure_cached=True cannot fall back, so StorageFull escapes.
                    with self.assertRaises(StorageFull):
                        receiver_ref_p.create_data_writer(session_id,
                                                          chunk_key4,
                                                          data_size,
                                                          test_actor,
                                                          ensure_cached=True,
                                                          use_promise=False)
                    # test receive aborted
                    receiver_ref_p.create_data_writer(
                        session_id, chunk_key4, data_size, test_actor, ensure_cached=False, _promise=True) \
                        .then(lambda *s: test_actor.set_result(s))
                    self.assertTupleEqual(self.get_result(5),
                                          (receiver_ref.address, None))

                    receiver_ref_p.register_finish_callback(session_id, chunk_key4, _promise=True) \
                        .then(lambda *s: test_actor.set_result(s)) \
                        .catch(lambda *exc: test_actor.set_result(exc, accept=False))

                    receiver_ref_p.receive_data_part(
                        session_id, chunk_key4, serialized_mock_data[:64],
                        zlib.crc32(serialized_mock_data[:64]))
                    receiver_ref_p.cancel_receive(session_id, chunk_key4)
                    with self.assertRaises(ExecutionInterrupted):
                        self.get_result(5)

                    # test receive into spill
                    receiver_ref_p.create_data_writer(
                        session_id, chunk_key4, data_size, test_actor, ensure_cached=False, _promise=True) \
                        .then(lambda *s: test_actor.set_result(s))
                    self.assertTupleEqual(self.get_result(5),
                                          (receiver_ref.address, None))

                    receiver_ref_p.register_finish_callback(session_id, chunk_key4, _promise=True) \
                        .then(lambda *s: test_actor.set_result(s)) \
                        .catch(lambda *exc: test_actor.set_result(exc, accept=False))

                    receiver_ref_p.receive_data_part(session_id, chunk_key4,
                                                     serialized_mock_data,
                                                     serialized_crc32)
                    receiver_ref_p.finish_receive(session_id, chunk_key4,
                                                  serialized_crc32)

                    self.assertTupleEqual((), self.get_result(5))

                # test intermediate error
                # Shared-memory creation now raises SpillNotConfigured, which
                # must reach the reject handler of create_data_writer.
                def mocked_store_create(*_):
                    raise SpillNotConfigured

                with patch_method(PlasmaSharedStore.create,
                                  new=mocked_store_create):
                    receiver_ref_p.create_data_writer(
                        session_id, chunk_key5, data_size, test_actor, ensure_cached=False, _promise=True) \
                        .then(lambda *s: test_actor.set_result(s),
                              lambda *s: test_actor.set_result(s, accept=False))

                    with self.assertRaises(SpillNotConfigured):
                        self.get_result(5)

                # test receive timeout
                # Only 64 bytes are sent with timeout=2, so the finish callback
                # should fail with TimeoutError.
                receiver_ref_p.register_finish_callback(session_id, chunk_key6, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s)) \
                    .catch(lambda *exc: test_actor.set_result(exc, accept=False))

                receiver_ref_p.create_data_writer(session_id, chunk_key6, data_size, test_actor,
                                                  timeout=2, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.assertTupleEqual(self.get_result(5),
                                      (receiver_ref.address, None))
                receiver_ref_p.receive_data_part(
                    session_id, chunk_key6, serialized_mock_data[:64],
                    zlib.crc32(serialized_mock_data[:64]))

                with self.assertRaises(TimeoutError):
                    self.get_result(5)

                # test sender halt
                # The writer is created on behalf of a sender at 'MOCK_ADDR';
                # reporting that address dead should fail the receive.
                receiver_ref_p.register_finish_callback(session_id, chunk_key7, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s)) \
                    .catch(lambda *exc: test_actor.set_result(exc, accept=False))

                mock_ref = pool.actor_ref(test_actor.uid, address='MOCK_ADDR')
                receiver_ref_p.create_data_writer(
                    session_id, chunk_key7, data_size, mock_ref, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.assertTupleEqual(self.get_result(5),
                                      (receiver_ref.address, None))
                receiver_ref_p.receive_data_part(
                    session_id, chunk_key7, serialized_mock_data[:64],
                    zlib.crc32(serialized_mock_data[:64]))
                receiver_ref_p.notify_dead_senders(['MOCK_ADDR'])

                with self.assertRaises(WorkerDead):
                    self.get_result(5)

                # test checksum error on finish_receive
                # NOTE(review): unlike the earlier case with the same label, no
                # finish callback is registered and nothing asserts the outcome
                # of finish_receive(..., 0) below — confirm this is intended.
                result = receiver_ref_p.create_data_writer(session_id,
                                                           chunk_key8,
                                                           data_size,
                                                           test_actor,
                                                           use_promise=False)
                self.assertTupleEqual(result, (receiver_ref.address, None))

                receiver_ref_p.receive_data_part(session_id, chunk_key8,
                                                 serialized_mock_data,
                                                 serialized_crc32)
                receiver_ref_p.finish_receive(session_id, chunk_key8, 0)
Ejemplo n.º 7
0
    def testSender(self):
        """Drive ``SenderActor.send_data`` through its success and failure paths.

        A sender pool (real ``SenderActor``) and a receiver pool
        (``MockReceiverActor`` registered under ``ReceiverActor``'s default
        uid) are started, and the following are checked in order:

        * sending a key that is not stored raises ``DependencyMissing``;
        * data written to disk and data put in shared memory both reach the
          receiver intact;
        * sending to a list of two receivers, then repeating that send,
          succeeds and the second receiver ends up holding the data;
        * sending to an address whose pool has already been shut down raises
          ``BrokenPipeError``;
        * a receiver whose ``receive_data_part`` raises ``ChecksumMismatch``
          propagates that error back to the caller.
        """
        send_pool_addr, recv_pool_addr, recv_pool_addr2 = [
            'localhost:%d' % get_next_port() for _ in range(3)]

        # NOTE(review): neither this temporary spill directory nor the global
        # option is cleaned up / restored in this method; confirm tearDown does.
        options.worker.spill_directory = tempfile.mkdtemp(
            prefix='mars_test_sender_')
        session_id = str(uuid.uuid4())

        mock_data = np.array([1, 2, 3, 4])
        chunk_key1, chunk_key2 = (str(uuid.uuid4()) for _ in range(2))

        @contextlib.contextmanager
        def start_send_recv_pool():
            # Outer pool hosts the real sender, inner pool the mock receiver;
            # they are torn down in reverse order when the caller's block ends.
            with start_transfer_test_pool(
                    address=send_pool_addr,
                    plasma_size=self.plasma_storage_size) as sender_pool:
                sender_pool.create_actor(SenderActor,
                                         uid=SenderActor.default_uid())
                with start_transfer_test_pool(
                        address=recv_pool_addr,
                        plasma_size=self.plasma_storage_size) as receiver_pool:
                    receiver_pool.create_actor(MockReceiverActor,
                                               uid=ReceiverActor.default_uid())
                    yield sender_pool, receiver_pool

        with start_send_recv_pool() as (send_pool, recv_pool):
            sender_ref = send_pool.actor_ref(SenderActor.default_uid())
            receiver_ref = recv_pool.actor_ref(ReceiverActor.default_uid())

            with self.run_actor_test(send_pool) as test_actor:
                storage_client = test_actor.storage_client
                sender_ref_p = test_actor.promise_ref(sender_ref)

                def report(p):
                    # Forward a promise's value or exception into test_actor so
                    # that self.get_result() returns it or raises it.
                    p.then(lambda *s: test_actor.set_result(s)) \
                        .catch(lambda *exc: test_actor.set_result(exc, accept=False))

                def send_and_report(chunk_key, target):
                    report(sender_ref_p.send_data(
                        session_id, chunk_key, target, _promise=True))

                # send when data missing
                send_and_report(str(uuid.uuid4()), recv_pool_addr)
                with self.assertRaises(DependencyMissing):
                    self.get_result(5)

                # send data in spill
                serialized = dataserializer.serialize(mock_data)

                def write_and_close(writer):
                    # Write the serialized payload, then close the writer.
                    return promise.finished() \
                        .then(lambda *_: writer.write(serialized)) \
                        .then(lambda *_: writer.close())

                disk_writer_p = storage_client.create_writer(
                    session_id, chunk_key1, serialized.total_bytes,
                    [DataStorageDevice.DISK])
                self.waitp(disk_writer_p.then(write_and_close))

                send_and_report(chunk_key1, recv_pool_addr)
                self.get_result(5)
                assert_array_equal(
                    mock_data,
                    receiver_ref.get_result_data(session_id, chunk_key1))
                storage_client.delete(session_id, chunk_key1)

                # send data in plasma store
                self.waitp(storage_client.put_object(
                    session_id, chunk_key1, mock_data,
                    [DataStorageDevice.SHARED_MEMORY]))

                send_and_report(chunk_key1, recv_pool_addr)
                self.get_result(5)
                assert_array_equal(
                    mock_data,
                    receiver_ref.get_result_data(session_id, chunk_key1))

                # send data to multiple targets
                with start_transfer_test_pool(
                        address=recv_pool_addr2,
                        plasma_size=self.plasma_storage_size) as rp2:
                    recv_ref2 = rp2.create_actor(
                        MockReceiverActor, uid=ReceiverActor.default_uid())

                    self.waitp(sender_ref_p.send_data(
                        session_id, chunk_key1,
                        [recv_pool_addr, recv_pool_addr2], _promise=True))
                    # send data to already transferred / transferring
                    send_and_report(chunk_key1,
                                    [recv_pool_addr, recv_pool_addr2])
                    self.get_result(5)
                    assert_array_equal(
                        mock_data,
                        recv_ref2.get_result_data(session_id, chunk_key1))

                # send data to non-exist endpoint which causes error
                # (the pool at recv_pool_addr2 was shut down above)
                self.waitp(storage_client.put_object(
                    session_id, chunk_key2, mock_data,
                    [DataStorageDevice.SHARED_MEMORY]))

                send_and_report(chunk_key2, recv_pool_addr2)
                with self.assertRaises(BrokenPipeError):
                    self.get_result(5)

                def mocked_receive_data_part(*_):
                    raise ChecksumMismatch

                with patch_method(MockReceiverActor.receive_data_part,
                                  new=mocked_receive_data_part):
                    send_and_report(chunk_key2, recv_pool_addr)
                    with self.assertRaises(ChecksumMismatch):
                        self.get_result(5)
Ejemplo n.º 8
0
    def testReceiverManager(self):
        """Exercise ``ReceiverManagerActor`` against a mock receiver worker.

        Scenarios run in order against one pool and share state:

        1. a key already in storage is reported as ``ReceiveStatus.RECEIVED``
           and a callback on it resolves immediately with ``()``;
        2. new keys get a writer (status ``None``), a repeated request for a
           key still in flight reports ``ReceiveStatus.RECEIVING``, and a
           callback fires once the data has been delivered;
        3. a callback on several keys does not fire until every key arrives;
        4. a cancelled receive makes the callback fail with
           ``ExecutionInterrupted``;
        5. ``create_data_writers`` called with ``use_promise=False`` returns
           ``(ref, statuses)`` directly;
        6. a key deleted from storage can be transferred again.
        """
        pool_addr = f'localhost:{get_next_port()}'
        session_id = str(uuid.uuid4())

        mock_data = np.array([1, 2, 3, 4])
        serialized_data = dataserializer.dumps(mock_data)
        data_size = len(serialized_data)

        chunk_key1 = str(uuid.uuid4())
        chunk_key2 = str(uuid.uuid4())
        chunk_key3 = str(uuid.uuid4())
        chunk_key4 = str(uuid.uuid4())
        chunk_key5 = str(uuid.uuid4())
        chunk_key6 = str(uuid.uuid4())
        chunk_key7 = str(uuid.uuid4())

        with start_transfer_test_pool(address=pool_addr, plasma_size=self.plasma_storage_size) as pool, \
                self.run_actor_test(pool) as test_actor:
            # Mock receiver worker. Assertions below expect every writer the
            # manager hands out to live on this actor (matching uid).
            mock_receiver_ref = pool.create_actor(MockReceiverWorkerActor,
                                                  uid=str(uuid.uuid4()))
            storage_client = test_actor.storage_client
            # NOTE(review): the manager actor is not created in this method;
            # presumably start_transfer_test_pool sets it up — confirm.
            receiver_manager_ref = test_actor.promise_ref(
                ReceiverManagerActor.default_uid())

            # SCENARIO 1: test transferring existing keys
            # Store chunk_key1 on disk first so it already exists when the
            # transfer is requested.
            self.waitp(
                storage_client.create_writer(
                    session_id, chunk_key1, data_size,
                    [DataStorageDevice.DISK
                     ]).then(lambda writer: promise.finished().then(
                         lambda *_: writer.write(serialized_data)).then(
                             lambda *_: writer.close())))
            # The result is (receiver ref, per-key status list).
            result = self.waitp(
                receiver_manager_ref.create_data_writers(session_id,
                                                         [chunk_key1],
                                                         [data_size],
                                                         test_actor,
                                                         _promise=True))
            self.assertEqual(result[0].uid, mock_receiver_ref.uid)
            self.assertEqual(result[1][0], ReceiveStatus.RECEIVED)

            # test adding callback for transferred key (should return immediately)
            result = self.waitp(
                receiver_manager_ref.add_keys_callback(session_id,
                                                       [chunk_key1],
                                                       _promise=True))
            self.assertTupleEqual(result, ())

            # Only chunk_key2 is expected back. chunk_key1 is already stored
            # and 'non_exist' was never registered.
            receiver_manager_ref.register_pending_keys(
                session_id, [chunk_key1, chunk_key2])
            self.assertEqual(
                receiver_manager_ref.filter_receiving_keys(
                    session_id, [chunk_key1, chunk_key2, 'non_exist']),
                [chunk_key2])

            # SCENARIO 2: test transferring new keys and wait on listeners
            # A None status is what this test expects for a key that still
            # needs to be transferred.
            result = self.waitp(
                receiver_manager_ref.create_data_writers(
                    session_id, [chunk_key2, chunk_key3], [data_size] * 2,
                    test_actor,
                    _promise=True))
            self.assertEqual(result[0].uid, mock_receiver_ref.uid)
            self.assertIsNone(result[1][0])

            # transfer with transferring keys will report RECEIVING
            result = self.waitp(
                receiver_manager_ref.create_data_writers(session_id,
                                                         [chunk_key2],
                                                         [data_size],
                                                         test_actor,
                                                         _promise=True))
            self.assertEqual(result[1][0], ReceiveStatus.RECEIVING)

            # add listener and finish transfer
            # The listener is registered before the data parts are delivered,
            # so it must be woken up by the delivery.
            # NOTE(review): the [True] argument looks like a per-key flag list
            # interpreted by MockReceiverWorkerActor — confirm its meaning.
            receiver_manager_ref.add_keys_callback(session_id, [chunk_key1, chunk_key2], _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            mock_receiver_ref.receive_data_part(session_id, [chunk_key2],
                                                [True], serialized_data)
            mock_receiver_ref.receive_data_part(session_id, [chunk_key3],
                                                [True], serialized_data)
            self.get_result(5)

            # SCENARIO 3: test listening on multiple transfers
            # Consume the writer-creation result before registering the
            # listener, so the next get_result() only sees the callback.
            receiver_manager_ref.create_data_writers(
                session_id, [chunk_key4, chunk_key5], [data_size] * 2, test_actor, _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            self.get_result(5)
            # add listener
            receiver_manager_ref.add_keys_callback(session_id, [chunk_key4, chunk_key5], _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            mock_receiver_ref.receive_data_part(session_id, [chunk_key4],
                                                [True], serialized_data)
            # when some chunks are not transferred, promise will not return
            # (a short timeout is enough to show nothing has arrived yet)
            with self.assertRaises(TimeoutError):
                self.get_result(0.5)
            mock_receiver_ref.receive_data_part(session_id, [chunk_key5],
                                                [True], serialized_data)
            self.get_result(5)

            # SCENARIO 4: test listening on transfer with errors
            self.waitp(
                receiver_manager_ref.create_data_writers(session_id,
                                                         [chunk_key6],
                                                         [data_size],
                                                         test_actor,
                                                         _promise=True))
            # Cancelling the receive is expected to fail the pending callback
            # with ExecutionInterrupted.
            receiver_manager_ref.add_keys_callback(session_id, [chunk_key6], _promise=True) \
                .then(lambda *s: test_actor.set_result(s)) \
                .catch(lambda *exc: test_actor.set_result(exc, accept=False))
            mock_receiver_ref.cancel_receive(session_id, [chunk_key6])
            with self.assertRaises(ExecutionInterrupted):
                self.get_result(5)

            # SCENARIO 5: test creating writers without promise
            # With use_promise=False the call returns (ref, statuses) directly.
            ref, statuses = receiver_manager_ref.create_data_writers(
                session_id, [chunk_key7], [data_size],
                test_actor,
                use_promise=False)
            self.assertIsNone(statuses[0])
            self.assertEqual(ref.uid, mock_receiver_ref.uid)

            # SCENARIO 6: test transferring lost keys
            # Delete the stored copy so chunk_key1 must be transferred again.
            storage_client.delete(session_id, [chunk_key1])

            result = self.waitp(
                receiver_manager_ref.create_data_writers(session_id,
                                                         [chunk_key1],
                                                         [data_size],
                                                         test_actor,
                                                         _promise=True))
            self.assertEqual(result[0].uid, mock_receiver_ref.uid)
            self.assertIsNone(result[1][0])

            # add listener and finish transfer
            receiver_manager_ref.add_keys_callback(session_id, [chunk_key1], _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            mock_receiver_ref.receive_data_part(session_id, [chunk_key1],
                                                [True], serialized_data)
            self.get_result(5)