def _create_writer_with_fail(self, session_id, chunk_key, *args, **kwargs):
    """Stand-in for ``create_writer`` that fails for the chunk ``fail_key``.

    Calls for any other chunk key are forwarded unchanged to
    ``old_create_writer``.  For ``fail_key`` the failure is delivered as a
    promise finished with ``_accept=False`` carrying a ``ValueError`` when
    the ``_promise`` keyword is truthy (it defaults to ``True``), and as a
    directly raised ``ValueError`` otherwise.
    """
    if chunk_key != fail_key:
        return old_create_writer(self, session_id, chunk_key, *args, **kwargs)
    if not kwargs.get('_promise', True):
        raise ValueError
    return promise.finished(*build_exc_info(ValueError), _accept=False)
def run_test_cache(self):
    """Overfill the chunk store and check spilling, reloading and eviction.

    Nine random ``int16`` arrays are put into ``self._chunk_store`` and
    registered with the chunk holder.  Whenever a put raises ``StoreFull``
    the holder is asked to spill twice the data size and the put is retried.
    The promise chain then checks ``chunk_store.contains`` for the last and
    first chunk around ``ensure_chunk`` calls and around
    unregister-plus-evict calls.  A failure is stored in ``self._exc_info``
    and the chain ends by setting ``self._finished`` to ``True``.
    """
    session_id = str(uuid.uuid4())

    holder_ref = self.promise_ref(ChunkHolderActor.default_name())
    store = self._chunk_store

    chunks = [
        (str(uuid.uuid4()), np.random.randint(0, 32767, (655360, ), np.int16))
        for _ in range(9)
    ]

    def _ignore_no_data_to_spill(*exc):
        # only NoDataToSpill is swallowed, everything else propagates
        if not issubclass(exc[0], NoDataToSpill):
            six.reraise(*exc)

    def _put_chunk(data_key, data, *_):
        try:
            ref = store.put(session_id, data_key, data)
            holder_ref.register_chunk(session_id, data_key)
            self.ctx.sleep(0.5)
            del ref
        except StoreFull:
            # ask the holder to spill, then try the same chunk again
            spilled = holder_ref.spill_size(calc_data_size(data) * 2, _promise=True)
            retry = partial(_put_chunk, data_key, data)
            return spilled.then(retry, _ignore_no_data_to_spill)

    put_promises = [
        promise.finished().then(partial(_put_chunk, chunk_id, chunk_data))
        for chunk_id, chunk_data in chunks
    ]

    def _expect(flag):
        assert flag

    def _stored(key):
        return store.contains(session_id, key)

    first_id = chunks[0][0]
    last_id = chunks[-1][0]

    chain = promise.all_(put_promises) \
        .then(lambda *_: _expect(_stored(last_id))) \
        .then(lambda *_: ensure_chunk(self, session_id, last_id)) \
        .then(lambda *_: _expect(_stored(last_id))) \
        .then(lambda *_: _expect(not _stored(first_id))) \
        .then(lambda *_: ensure_chunk(self, session_id, first_id)) \
        .then(lambda *_: _expect(_stored(first_id))) \
        .then(lambda *_: holder_ref.unregister_chunk(session_id, first_id)) \
        .then(lambda *_: self._plasma_client.evict(128)) \
        .then(lambda *_: _expect(not _stored(first_id))) \
        .then(lambda *_: holder_ref.unregister_chunk(session_id, last_id)) \
        .then(lambda *_: self._plasma_client.evict(128)) \
        .then(lambda *_: _expect(not _stored(last_id)))

    chain.catch(lambda *exc: setattr(self, '_exc_info', exc)) \
        .then(lambda *_: setattr(self, '_finished', True))
def _mock_load_from(*_, **__):
    """Ignore every argument and return a promise failed with ``SystemError``.

    The promise is created with ``_accept=False`` from the exception info
    built by ``build_exc_info(SystemError)``.
    """
    failure = build_exc_info(SystemError)
    return promise.finished(*failure, _accept=False)
def testDispatch(self, *_):
    """Check slot bookkeeping and queued slot acquisition of ``DispatchActor``.

    Two groups (``g1`` and ``g2``) of ``group_size`` task actors are created.
    After checking slot counts and hash-slot stability, ``group_size + 1``
    rounds of slot acquisition are chained for both groups, followed by one
    acquisition on the empty group ``g3``.  The recorded values in
    ``call_records`` must be close together for the first ``group_size``
    rounds, while the extra round must lag by roughly ``delay``.
    """
    call_records = {}
    group_size = 4

    mock_scheduler_addr = f'127.0.0.1:{get_next_port()}'
    with create_actor_pool(n_process=1, backend='gevent',
                           address=mock_scheduler_addr) as pool:
        dispatch_ref = pool.create_actor(
            DispatchActor, uid=DispatchActor.default_uid())

        # actors of g1, then actors of g2
        for queue_name in ('g1', 'g2'):
            for _n in range(group_size):
                pool.create_actor(TaskActor, queue_name, call_records)

        self.assertEqual(len(dispatch_ref.get_slots('g1')), group_size)
        self.assertEqual(len(dispatch_ref.get_slots('g2')), group_size)
        self.assertEqual(len(dispatch_ref.get_slots('g3')), 0)

        # the hash slot of the same key must be stable between calls
        first_slot = dispatch_ref.get_hash_slot('g1', 'hash_str')
        second_slot = dispatch_ref.get_hash_slot('g1', 'hash_str')
        self.assertEqual(first_slot, second_slot)

        dispatch_ref.acquire_free_slot(
            'g1', callback=(('NonExist', mock_scheduler_addr), '_non_exist', {}))
        free_slots = dispatch_ref.get_free_slots_num()
        self.assertEqual(free_slots.get('g1'), group_size)

        # tasks within [0, group_size - 1] will run almost simultaneously,
        # while the last one will be delayed due to lack of slots
        delay = 1

        with self.run_actor_test(pool) as test_actor:
            chain = promise.finished()
            dispatch_promise_ref = test_actor.promise_ref(
                DispatchActor.default_uid())

            def _call_on_dispatched(uid, key=None):
                if uid is not None:
                    test_actor.promise_ref(uid).queued_call(
                        key, delay, _tell=True, _wait=False)
                    return
                call_records[key] = 'NoneUID'

            for idx in range(group_size + 1):
                chain = chain \
                    .then(lambda *_: dispatch_promise_ref.acquire_free_slot('g1', _promise=True)) \
                    .then(partial(_call_on_dispatched, key=f'{idx}_1')) \
                    .then(lambda *_: dispatch_promise_ref.acquire_free_slot('g2', _promise=True)) \
                    .then(partial(_call_on_dispatched, key=f'{idx}_2'))

            chain.then(lambda *_: dispatch_promise_ref.acquire_free_slot('g3', _promise=True)) \
                .then(partial(_call_on_dispatched, key='N_1')) \
                .then(lambda *_: test_actor.set_result(None))

            # NOTE(review): the source indentation was lost; the wait is kept
            # inside the test-actor context -- confirm against the original.
            self.get_result(20)

        self.assertEqual(call_records['N_1'], 'NoneUID')

        base = call_records['0_1']
        spread = sum(abs(call_records[f'{idx}_1'] - base)
                     for idx in range(group_size))
        self.assertLess(spread, delay * 0.5)

        overflow_lag = call_records[f'{group_size}_1'] - base
        self.assertGreater(overflow_lag, delay * 0.5)
        self.assertLess(overflow_lag, delay * 1.5)

        dispatch_ref.destroy()
def testPromise(self):
    """Exercise ``then`` / ``catch`` chaining of promises resolved by a thread.

    A daemon worker thread takes ``(promise id, value, success)`` requests
    from ``req_queue``, records ``('thread_body', value)``, sleeps for
    ``time_unit`` and then calls ``step_next`` on the promise with the value
    and ``_accept=success``.  ``gen_promise`` records ``('gen_promise', value)``
    and queues a request for ``value + 1``.  Every scenario resets
    ``value_list`` and compares the recorded sequence with the expected one.

    On exit the promise pool must be empty.  The shutdown sentinel for the
    worker thread is sent from a nested ``finally`` so that a failing pool
    assertion cannot leave the thread blocked on ``req_queue.get()``.
    """
    # weak values: entries exist only while the promise itself is alive
    promises = weakref.WeakValueDictionary()
    req_queue = Queue()
    value_list = []

    time_unit = 0.1

    def test_thread_body():
        while True:
            idx, v, success = req_queue.get()
            if v is None:
                # sentinel pushed from the finally clause below
                break
            value_list.append(('thread_body', v))
            time.sleep(time_unit)
            promises[idx].step_next([(v, ), dict(_accept=success)])

    try:
        thread = threading.Thread(target=test_thread_body)
        thread.daemon = True
        thread.start()

        def gen_promise(value, accept=True):
            value_list.append(('gen_promise', value))
            p = promise.Promise()
            promises[p.id] = p
            req_queue.put((p.id, value + 1, accept))
            return p

        # simple promise call
        value_list = []
        p = gen_promise(0) \
            .then(lambda v: gen_promise(v)) \
            .then(lambda v: gen_promise(v))
        p.wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1),
            ('gen_promise', 1), ('thread_body', 2),
            ('gen_promise', 2), ('thread_body', 3)])

        # continue accepted call with then
        value_list = []
        p.then(lambda *_: gen_promise(0)) \
            .then(lambda v: gen_promise(v)) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1),
            ('gen_promise', 1), ('thread_body', 2)])

        # immediate error
        value_list = []
        p = promise.finished() \
            .then(lambda *_: 5 / 0)
        p.catch(lambda *_: gen_promise(0)) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1)])

        # chained errors
        value_list = []
        p = promise.finished(_accept=False) \
            .catch(lambda *_: 1 / 0) \
            .catch(lambda *_: 2 / 0) \
            .catch(lambda *_: gen_promise(0)) \
            .catch(lambda *_: gen_promise(1))
        p.wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1)])

        # continue error call
        value_list = []
        p = gen_promise(0) \
            .then(lambda *_: 5 / 0) \
            .then(lambda *_: gen_promise(2))
        time.sleep(0.5)
        value_list = []
        p.catch(lambda *_: gen_promise(0)) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1)])

        # handler returning a chained promise: the x + 1 step feeds the next
        value_list = []
        gen_promise(0) \
            .then(lambda v: gen_promise(v).then(lambda x: x + 1)) \
            .then(lambda v: gen_promise(v)) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1),
            ('gen_promise', 1), ('thread_body', 2),
            ('gen_promise', 3), ('thread_body', 4)])

        # promise finished with accept=False ends up in catch
        value_list = []
        gen_promise(0) \
            .then(lambda v: gen_promise(v)) \
            .then(lambda v: gen_promise(v, False)) \
            .catch(lambda v: value_list.append(('catch', v))) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1),
            ('gen_promise', 1), ('thread_body', 2),
            ('gen_promise', 2), ('thread_body', 3),
            ('catch', 3)])

        # failing inner promise skips both its own then and the outer then
        value_list = []
        gen_promise(0) \
            .then(lambda v: gen_promise(v, False).then(lambda x: x + 1)) \
            .then(lambda v: gen_promise(v)) \
            .catch(lambda v: value_list.append(('catch', v))) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1),
            ('gen_promise', 1), ('thread_body', 2),
            ('catch', 2)])

        # plain (non-promise) return values are passed on to the next step
        value_list = []
        gen_promise(0) \
            .then(lambda v: gen_promise(v)) \
            .then(lambda v: v + 1) \
            .then(lambda v: gen_promise(v, False)) \
            .catch(lambda v: value_list.append(('catch', v))) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1),
            ('gen_promise', 1), ('thread_body', 2),
            ('gen_promise', 3), ('thread_body', 4),
            ('catch', 4)])

        # catch handlers are skipped while the chain is still accepted
        value_list = []
        gen_promise(0) \
            .then(lambda v: gen_promise(v)) \
            .catch(lambda v: gen_promise(v)) \
            .catch(lambda v: gen_promise(v)) \
            .then(lambda v: gen_promise(v, False)) \
            .catch(lambda v: value_list.append(('catch', v))) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1),
            ('gen_promise', 1), ('thread_body', 2),
            ('gen_promise', 2), ('thread_body', 3),
            ('catch', 3)])

        # two-argument then: the second callable handles the failure
        value_list = []
        gen_promise(0) \
            .then(lambda v: gen_promise(v)) \
            .catch(lambda v: gen_promise(v)) \
            .catch(lambda v: gen_promise(v)) \
            .then(lambda v: gen_promise(v, False)) \
            .then(lambda v: gen_promise(v), lambda v: gen_promise(v + 1, False)) \
            .catch(lambda v: value_list.append(('catch', v))) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1),
            ('gen_promise', 1), ('thread_body', 2),
            ('gen_promise', 2), ('thread_body', 3),
            ('gen_promise', 4), ('thread_body', 5),
            ('catch', 5)])

        # failure handler that itself raises is caught by the next catch
        value_list = []
        gen_promise(0) \
            .then(lambda v: gen_promise(v)) \
            .catch(lambda v: gen_promise(v)) \
            .catch(lambda v: gen_promise(v)) \
            .then(lambda v: gen_promise(v, False)) \
            .then(lambda v: gen_promise(v), lambda v: _raise_exception(ValueError)) \
            .catch(lambda *_: value_list.append(('catch',))) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1),
            ('gen_promise', 1), ('thread_body', 2),
            ('gen_promise', 2), ('thread_body', 3),
            ('catch', )])

        # a catch that recovers lets the following then run again
        value_list = []
        gen_promise(0) \
            .then(lambda v: gen_promise(v, False)) \
            .catch(lambda v: gen_promise(v, False)) \
            .catch(lambda v: gen_promise(v)) \
            .then(lambda v: gen_promise(v)) \
            .catch(lambda v: value_list.append(('catch', v))) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1),
            ('gen_promise', 1), ('thread_body', 2),
            ('gen_promise', 2), ('thread_body', 3),
            ('gen_promise', 3), ('thread_body', 4),
            ('gen_promise', 4), ('thread_body', 5)])

        # then handlers after a failure are skipped until the catch
        value_list = []
        gen_promise(0) \
            .then(lambda v: gen_promise(v, False)) \
            .then(lambda v: gen_promise(v)) \
            .catch(lambda v: value_list.append(('catch', v))) \
            .wait()
        self.assertListEqual(value_list, [
            ('gen_promise', 0), ('thread_body', 1),
            ('gen_promise', 1), ('thread_body', 2),
            ('catch', 2)])
    finally:
        try:
            self.assertDictEqual(promise._promise_pool, {})
        finally:
            # always release the worker thread, even if the check above fails
            req_queue.put((None, None, None))
def testReceiver(self):
    """Drive ``ReceiverActor`` through its receive scenarios.

    Covered in order: status checks for data already on disk / in shared
    memory, writer creation with and without promises, checksum mismatches,
    cancellation, in-memory transfer, transfer while
    ``PlasmaSharedStore.create`` is patched to raise, receive timeout and
    dead-sender notification.  Results and errors are reported through
    ``test_actor.set_result`` and collected with ``self.get_result``.
    """
    pool_addr = 'localhost:%d' % get_next_port()
    options.worker.spill_directory = tempfile.mkdtemp(
        prefix='mars_test_receiver_')
    session_id = str(uuid.uuid4())

    mock_data = np.array([1, 2, 3, 4])
    serialized_arrow_data = dataserializer.serialize(mock_data)
    data_size = serialized_arrow_data.total_bytes
    serialized_mock_data = serialized_arrow_data.to_buffer()
    # checksum of the complete serialized buffer
    serialized_crc32 = zlib.crc32(serialized_arrow_data.to_buffer())

    # one fresh key per scenario group below
    chunk_key1 = str(uuid.uuid4())
    chunk_key2 = str(uuid.uuid4())
    chunk_key3 = str(uuid.uuid4())
    chunk_key4 = str(uuid.uuid4())
    chunk_key5 = str(uuid.uuid4())
    chunk_key6 = str(uuid.uuid4())
    chunk_key7 = str(uuid.uuid4())
    chunk_key8 = str(uuid.uuid4())

    with start_transfer_test_pool(
            address=pool_addr, plasma_size=self.plasma_storage_size) as pool:
        receiver_ref = pool.create_actor(ReceiverActor, uid=str(uuid.uuid4()))

        with self.run_actor_test(pool) as test_actor:
            storage_client = test_actor.storage_client

            # check_status on receiving and received
            self.assertEqual(
                receiver_ref.check_status(session_id, chunk_key1),
                ReceiveStatus.NOT_STARTED)

            # write the chunk to the DISK device, then it counts as RECEIVED
            self.waitp(
                storage_client.create_writer(
                    session_id, chunk_key1, serialized_arrow_data.total_bytes,
                    [DataStorageDevice.DISK]).then(
                        lambda writer: promise.finished().then(
                            lambda *_: writer.write(serialized_arrow_data)).
                        then(lambda *_: writer.close())))
            self.assertEqual(
                receiver_ref.check_status(session_id, chunk_key1),
                ReceiveStatus.RECEIVED)
            storage_client.delete(session_id, chunk_key1)

            # same check with the chunk stored in SHARED_MEMORY
            self.waitp(
                storage_client.put_object(
                    session_id, chunk_key1, mock_data,
                    [DataStorageDevice.SHARED_MEMORY]))
            self.assertEqual(
                receiver_ref.check_status(session_id, chunk_key1),
                ReceiveStatus.RECEIVED)

            receiver_ref_p = test_actor.promise_ref(receiver_ref)

            # cancel on an un-run / missing result will result in nothing
            receiver_ref_p.cancel_receive(session_id, chunk_key2)

            # start creating writer
            receiver_ref_p.create_data_writer(session_id, chunk_key1, data_size, test_actor, _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            self.assertTupleEqual(
                self.get_result(5), (receiver_ref.address, ReceiveStatus.RECEIVED))

            result = receiver_ref_p.create_data_writer(session_id, chunk_key1, data_size, test_actor,
                                                       use_promise=False)
            self.assertTupleEqual(
                result, (receiver_ref.address, ReceiveStatus.RECEIVED))

            # first writer for a new key reports status None
            receiver_ref_p.create_data_writer(session_id, chunk_key2, data_size, test_actor, _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            self.assertTupleEqual(self.get_result(5), (receiver_ref.address, None))

            # further requests for the same key report RECEIVING
            result = receiver_ref_p.create_data_writer(session_id, chunk_key2, data_size, test_actor,
                                                       use_promise=False)
            self.assertTupleEqual(
                result, (receiver_ref.address, ReceiveStatus.RECEIVING))

            receiver_ref_p.create_data_writer(session_id, chunk_key2, data_size, test_actor, _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            self.assertTupleEqual(
                self.get_result(5), (receiver_ref.address, ReceiveStatus.RECEIVING))

            # cancelling puts the key back to NOT_STARTED
            receiver_ref_p.cancel_receive(session_id, chunk_key2)
            self.assertEqual(
                receiver_ref.check_status(session_id, chunk_key2),
                ReceiveStatus.NOT_STARTED)

            # test checksum error on receive_data_part
            receiver_ref_p.create_data_writer(session_id, chunk_key2, data_size, test_actor, _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            self.get_result(5)

            receiver_ref_p.register_finish_callback(session_id, chunk_key2, _promise=True) \
                .then(lambda *s: test_actor.set_result(s)) \
                .catch(lambda *exc: test_actor.set_result(exc, accept=False))

            # checksum 0 is passed instead of the checksum of the data
            receiver_ref_p.receive_data_part(session_id, chunk_key2, serialized_mock_data, 0)

            with self.assertRaises(ChecksumMismatch):
                self.get_result(5)

            # test checksum error on finish_receive
            receiver_ref_p.create_data_writer(session_id, chunk_key2, data_size, test_actor, _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            self.assertTupleEqual(self.get_result(5), (receiver_ref.address, None))

            receiver_ref_p.receive_data_part(session_id, chunk_key2, serialized_mock_data, serialized_crc32)
            # the part carried the right checksum, the final one is 0
            receiver_ref_p.finish_receive(session_id, chunk_key2, 0)

            receiver_ref_p.register_finish_callback(session_id, chunk_key2, _promise=True) \
                .then(lambda *s: test_actor.set_result(s)) \
                .catch(lambda *exc: test_actor.set_result(exc, accept=False))
            with self.assertRaises(ChecksumMismatch):
                self.get_result(5)

            receiver_ref_p.cancel_receive(session_id, chunk_key2)

            # test intermediate cancellation
            receiver_ref_p.create_data_writer(session_id, chunk_key2, data_size, test_actor, _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            self.assertTupleEqual(self.get_result(5), (receiver_ref.address, None))

            receiver_ref_p.register_finish_callback(session_id, chunk_key2, _promise=True) \
                .then(lambda *s: test_actor.set_result(s)) \
                .catch(lambda *exc: test_actor.set_result(exc, accept=False))

            # send the first 64 bytes, cancel, then send the remainder
            receiver_ref_p.receive_data_part(
                session_id, chunk_key2, serialized_mock_data[:64],
                zlib.crc32(serialized_mock_data[:64]))
            receiver_ref_p.cancel_receive(session_id, chunk_key2)
            receiver_ref_p.receive_data_part(session_id, chunk_key2, serialized_mock_data[64:], serialized_crc32)
            with self.assertRaises(ExecutionInterrupted):
                self.get_result(5)

            # test transfer in memory
            receiver_ref_p.register_finish_callback(session_id, chunk_key3, _promise=True) \
                .then(lambda *s: test_actor.set_result(s)) \
                .catch(lambda *exc: test_actor.set_result(exc, accept=False))

            receiver_ref_p.create_data_writer(session_id, chunk_key3, data_size, test_actor, _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            self.assertTupleEqual(self.get_result(5), (receiver_ref.address, None))

            receiver_ref_p.receive_data_part(
                session_id, chunk_key3, serialized_mock_data[:64],
                zlib.crc32(serialized_mock_data[:64]))
            receiver_ref_p.receive_data_part(session_id, chunk_key3, serialized_mock_data[64:], serialized_crc32)
            receiver_ref_p.finish_receive(session_id, chunk_key3, serialized_crc32)

            # the finish callback resolves with no values
            self.assertTupleEqual((), self.get_result(5))

            receiver_ref_p.create_data_writer(session_id, chunk_key3, data_size, test_actor, _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            self.assertTupleEqual(
                self.get_result(5), (receiver_ref.address, ReceiveStatus.RECEIVED))

            # test transfer in spill file
            def mocked_store_create(*_):
                raise StorageFull

            # NOTE(review): source indentation was lost; the two chunk_key4
            # scenarios are assumed to run under this patch -- confirm.
            with patch_method(PlasmaSharedStore.create, new=mocked_store_create):
                with self.assertRaises(StorageFull):
                    receiver_ref_p.create_data_writer(session_id, chunk_key4, data_size, test_actor,
                                                      ensure_cached=True, use_promise=False)
                # test receive aborted
                receiver_ref_p.create_data_writer(
                    session_id, chunk_key4, data_size, test_actor, ensure_cached=False, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.assertTupleEqual(self.get_result(5), (receiver_ref.address, None))

                receiver_ref_p.register_finish_callback(session_id, chunk_key4, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s)) \
                    .catch(lambda *exc: test_actor.set_result(exc, accept=False))

                receiver_ref_p.receive_data_part(
                    session_id, chunk_key4, serialized_mock_data[:64],
                    zlib.crc32(serialized_mock_data[:64]))
                receiver_ref_p.cancel_receive(session_id, chunk_key4)
                with self.assertRaises(ExecutionInterrupted):
                    self.get_result(5)

                # test receive into spill
                receiver_ref_p.create_data_writer(
                    session_id, chunk_key4, data_size, test_actor, ensure_cached=False, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s))
                self.assertTupleEqual(self.get_result(5), (receiver_ref.address, None))

                receiver_ref_p.register_finish_callback(session_id, chunk_key4, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s)) \
                    .catch(lambda *exc: test_actor.set_result(exc, accept=False))

                receiver_ref_p.receive_data_part(session_id, chunk_key4, serialized_mock_data, serialized_crc32)
                receiver_ref_p.finish_receive(session_id, chunk_key4, serialized_crc32)
                self.assertTupleEqual((), self.get_result(5))

            # test intermediate error
            def mocked_store_create(*_):
                raise SpillNotConfigured

            with patch_method(PlasmaSharedStore.create, new=mocked_store_create):
                # both outcomes are reported; the failure is re-raised by get_result
                receiver_ref_p.create_data_writer(
                    session_id, chunk_key5, data_size, test_actor, ensure_cached=False, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s),
                          lambda *s: test_actor.set_result(s, accept=False))

                with self.assertRaises(SpillNotConfigured):
                    self.get_result(5)

            # test receive timeout
            receiver_ref_p.register_finish_callback(session_id, chunk_key6, _promise=True) \
                .then(lambda *s: test_actor.set_result(s)) \
                .catch(lambda *exc: test_actor.set_result(exc, accept=False))
            receiver_ref_p.create_data_writer(session_id, chunk_key6, data_size, test_actor,
                                              timeout=2, _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            self.assertTupleEqual(self.get_result(5), (receiver_ref.address, None))
            # only the first 64 bytes are sent and the transfer is never finished
            receiver_ref_p.receive_data_part(
                session_id, chunk_key6, serialized_mock_data[:64],
                zlib.crc32(serialized_mock_data[:64]))

            with self.assertRaises(TimeoutError):
                self.get_result(5)

            # test sender halt
            receiver_ref_p.register_finish_callback(session_id, chunk_key7, _promise=True) \
                .then(lambda *s: test_actor.set_result(s)) \
                .catch(lambda *exc: test_actor.set_result(exc, accept=False))
            # sender reference whose address is reported dead below
            mock_ref = pool.actor_ref(test_actor.uid, address='MOCK_ADDR')
            receiver_ref_p.create_data_writer(
                session_id, chunk_key7, data_size, mock_ref, _promise=True) \
                .then(lambda *s: test_actor.set_result(s))
            self.assertTupleEqual(self.get_result(5), (receiver_ref.address, None))
            receiver_ref_p.receive_data_part(
                session_id, chunk_key7, serialized_mock_data[:64],
                zlib.crc32(serialized_mock_data[:64]))
            receiver_ref_p.notify_dead_senders(['MOCK_ADDR'])

            with self.assertRaises(WorkerDead):
                self.get_result(5)

            # test checksum error on finish_receive
            result = receiver_ref_p.create_data_writer(session_id, chunk_key8, data_size, test_actor,
                                                       use_promise=False)
            self.assertTupleEqual(result, (receiver_ref.address, None))

            receiver_ref_p.receive_data_part(session_id, chunk_key8, serialized_mock_data, serialized_crc32)
            # no finish callback is registered here, so nothing is asserted
            # about the outcome of this wrong final checksum
            receiver_ref_p.finish_receive(session_id, chunk_key8, 0)
def testSender(self):
    """Send chunks from a ``SenderActor`` pool to mock receiver pools.

    Scenarios, in order: sending a missing chunk, sending a chunk written to
    the DISK device, sending a chunk put into SHARED_MEMORY, sending to two
    receivers at once (and again once transferred), sending to an endpoint
    whose pool is gone, and sending while ``receive_data_part`` is patched to
    raise ``ChecksumMismatch``.
    """
    send_pool_addr, recv_pool_addr, recv_pool_addr2 = [
        'localhost:%d' % get_next_port() for _ in range(3)]

    options.worker.spill_directory = tempfile.mkdtemp(
        prefix='mars_test_sender_')
    session_id = str(uuid.uuid4())

    mock_data = np.array([1, 2, 3, 4])
    chunk_key1, chunk_key2 = [str(uuid.uuid4()) for _ in range(2)]

    @contextlib.contextmanager
    def start_send_recv_pool():
        # the sender pool wraps the receiver pool
        with start_transfer_test_pool(
                address=send_pool_addr,
                plasma_size=self.plasma_storage_size) as sender_pool:
            sender_pool.create_actor(SenderActor, uid=SenderActor.default_uid())
            with start_transfer_test_pool(
                    address=recv_pool_addr,
                    plasma_size=self.plasma_storage_size) as receiver_pool:
                receiver_pool.create_actor(
                    MockReceiverActor, uid=ReceiverActor.default_uid())
                yield sender_pool, receiver_pool

    with start_send_recv_pool() as (send_pool, recv_pool):
        sender_ref = send_pool.actor_ref(SenderActor.default_uid())
        receiver_ref = recv_pool.actor_ref(ReceiverActor.default_uid())

        with self.run_actor_test(send_pool) as test_actor:
            storage_client = test_actor.storage_client
            sender_ref_p = test_actor.promise_ref(sender_ref)

            def _send_and_report(chunk_key, targets):
                # forward success or failure of the send to the test actor
                sender_ref_p.send_data(session_id, chunk_key, targets, _promise=True) \
                    .then(lambda *s: test_actor.set_result(s)) \
                    .catch(lambda *exc: test_actor.set_result(exc, accept=False))

            # send when data missing
            _send_and_report(str(uuid.uuid4()), recv_pool_addr)
            with self.assertRaises(DependencyMissing):
                self.get_result(5)

            # send data in spill
            serialized = dataserializer.serialize(mock_data)

            def _fill_and_close(writer):
                return promise.finished() \
                    .then(lambda *_: writer.write(serialized)) \
                    .then(lambda *_: writer.close())

            disk_writer = storage_client.create_writer(
                session_id, chunk_key1, serialized.total_bytes,
                [DataStorageDevice.DISK])
            self.waitp(disk_writer.then(_fill_and_close))

            _send_and_report(chunk_key1, recv_pool_addr)
            self.get_result(5)
            assert_array_equal(
                mock_data, receiver_ref.get_result_data(session_id, chunk_key1))
            storage_client.delete(session_id, chunk_key1)

            # send data in plasma store
            self.waitp(
                storage_client.put_object(
                    session_id, chunk_key1, mock_data,
                    [DataStorageDevice.SHARED_MEMORY]))

            _send_and_report(chunk_key1, recv_pool_addr)
            self.get_result(5)
            assert_array_equal(
                mock_data, receiver_ref.get_result_data(session_id, chunk_key1))

            # send data to multiple targets
            with start_transfer_test_pool(
                    address=recv_pool_addr2,
                    plasma_size=self.plasma_storage_size) as rp2:
                recv_ref2 = rp2.create_actor(
                    MockReceiverActor, uid=ReceiverActor.default_uid())
                both_targets = [recv_pool_addr, recv_pool_addr2]

                self.waitp(
                    sender_ref_p.send_data(
                        session_id, chunk_key1, both_targets, _promise=True))

                # send data to already transferred / transferring
                _send_and_report(chunk_key1, [recv_pool_addr, recv_pool_addr2])
                self.get_result(5)
                assert_array_equal(
                    mock_data, recv_ref2.get_result_data(session_id, chunk_key1))

            # send data to non-exist endpoint which causes error
            self.waitp(
                storage_client.put_object(
                    session_id, chunk_key2, mock_data,
                    [DataStorageDevice.SHARED_MEMORY]))

            _send_and_report(chunk_key2, recv_pool_addr2)
            with self.assertRaises(BrokenPipeError):
                self.get_result(5)

            # receiver rejecting every data part with a checksum error
            def mocked_receive_data_part(*_):
                raise ChecksumMismatch

            with patch_method(MockReceiverActor.receive_data_part,
                              new=mocked_receive_data_part):
                _send_and_report(chunk_key2, recv_pool_addr)
                with self.assertRaises(ChecksumMismatch):
                    self.get_result(5)
def testReceiverManager(self):
    """Run ``ReceiverManagerActor`` through six transfer scenarios.

    A ``MockReceiverWorkerActor`` plays the receiver.  The scenarios cover
    keys that already exist, new keys with listeners, listeners on several
    transfers, transfers ending in an error, writer creation without
    promises, and keys whose data was deleted.
    """
    pool_addr = f'localhost:{get_next_port()}'
    session_id = str(uuid.uuid4())

    mock_data = np.array([1, 2, 3, 4])
    serialized_data = dataserializer.dumps(mock_data)
    data_size = len(serialized_data)

    (chunk_key1, chunk_key2, chunk_key3, chunk_key4,
     chunk_key5, chunk_key6, chunk_key7) = [str(uuid.uuid4()) for _ in range(7)]

    with start_transfer_test_pool(address=pool_addr, plasma_size=self.plasma_storage_size) as pool, \
            self.run_actor_test(pool) as test_actor:
        mock_receiver_ref = pool.create_actor(
            MockReceiverWorkerActor, uid=str(uuid.uuid4()))
        storage_client = test_actor.storage_client
        receiver_manager_ref = test_actor.promise_ref(
            ReceiverManagerActor.default_uid())

        def _create_writers(keys):
            # promise-based writer creation, one data_size entry per key
            return receiver_manager_ref.create_data_writers(
                session_id, keys, [data_size] * len(keys), test_actor, _promise=True)

        def _feed(key):
            # hand the serialized payload for one key to the mock receiver
            mock_receiver_ref.receive_data_part(
                session_id, [key], [True], serialized_data)

        def _report(*s):
            test_actor.set_result(s)

        def _fill_and_close(writer):
            return promise.finished() \
                .then(lambda *_: writer.write(serialized_data)) \
                .then(lambda *_: writer.close())

        # SCENARIO 1: test transferring existing keys
        disk_writer = storage_client.create_writer(
            session_id, chunk_key1, data_size, [DataStorageDevice.DISK])
        self.waitp(disk_writer.then(_fill_and_close))

        result = self.waitp(_create_writers([chunk_key1]))
        self.assertEqual(result[0].uid, mock_receiver_ref.uid)
        self.assertEqual(result[1][0], ReceiveStatus.RECEIVED)

        # test adding callback for transferred key (should return immediately)
        result = self.waitp(
            receiver_manager_ref.add_keys_callback(
                session_id, [chunk_key1], _promise=True))
        self.assertTupleEqual(result, ())

        receiver_manager_ref.register_pending_keys(
            session_id, [chunk_key1, chunk_key2])
        receiving_keys = receiver_manager_ref.filter_receiving_keys(
            session_id, [chunk_key1, chunk_key2, 'non_exist'])
        self.assertEqual(receiving_keys, [chunk_key2])

        # SCENARIO 2: test transferring new keys and wait on listeners
        result = self.waitp(_create_writers([chunk_key2, chunk_key3]))
        self.assertEqual(result[0].uid, mock_receiver_ref.uid)
        self.assertIsNone(result[1][0])

        # transfer with transferring keys will report RECEIVING
        result = self.waitp(_create_writers([chunk_key2]))
        self.assertEqual(result[1][0], ReceiveStatus.RECEIVING)

        # add listener and finish transfer
        receiver_manager_ref.add_keys_callback(session_id, [chunk_key1, chunk_key2], _promise=True) \
            .then(_report)
        _feed(chunk_key2)
        _feed(chunk_key3)
        self.get_result(5)

        # SCENARIO 3: test listening on multiple transfers
        _create_writers([chunk_key4, chunk_key5]).then(_report)
        self.get_result(5)

        # add listener
        receiver_manager_ref.add_keys_callback(session_id, [chunk_key4, chunk_key5], _promise=True) \
            .then(_report)
        _feed(chunk_key4)
        # when some chunks are not transferred, promise will not return
        with self.assertRaises(TimeoutError):
            self.get_result(0.5)
        _feed(chunk_key5)
        self.get_result(5)

        # SCENARIO 4: test listening on transfer with errors
        self.waitp(_create_writers([chunk_key6]))
        receiver_manager_ref.add_keys_callback(session_id, [chunk_key6], _promise=True) \
            .then(_report) \
            .catch(lambda *exc: test_actor.set_result(exc, accept=False))
        mock_receiver_ref.cancel_receive(session_id, [chunk_key6])
        with self.assertRaises(ExecutionInterrupted):
            self.get_result(5)

        # SCENARIO 5: test creating writers without promise
        ref, statuses = receiver_manager_ref.create_data_writers(
            session_id, [chunk_key7], [data_size], test_actor,
            use_promise=False)
        self.assertIsNone(statuses[0])
        self.assertEqual(ref.uid, mock_receiver_ref.uid)

        # SCENARIO 6: test transferring lost keys
        storage_client.delete(session_id, [chunk_key1])

        result = self.waitp(_create_writers([chunk_key1]))
        self.assertEqual(result[0].uid, mock_receiver_ref.uid)
        self.assertIsNone(result[1][0])

        # add listener and finish transfer
        receiver_manager_ref.add_keys_callback(session_id, [chunk_key1], _promise=True) \
            .then(_report)
        _feed(chunk_key1)
        self.get_result(5)