Exemple #1
0
    def setUp(self):
        """Start a one-process gevent actor pool and build a ``MarsAPI`` on it.

        The pool is bound to ``127.0.0.1`` on the port returned by
        ``get_next_port()`` and is populated with the cluster-info,
        session-manager and resource actors, each registered under its
        ``default_name()``.  The endpoint, pool and API client are stored on
        ``self`` for the tests to use.
        """
        endpoint = '127.0.0.1:%d' % get_next_port()
        self.endpoint = endpoint
        self.pool = create_actor_pool(n_process=1, backend='gevent', address=endpoint)
        try:
            self.pool.create_actor(SchedulerClusterInfoActor, [endpoint],
                                   uid=SchedulerClusterInfoActor.default_name())
            self.pool.create_actor(SessionManagerActor, uid=SessionManagerActor.default_name())
            self.pool.create_actor(ResourceActor, uid=ResourceActor.default_name())

            self.api = MarsAPI(endpoint)
        except BaseException:
            # A unittest runner does not call tearDown when setUp raises, so
            # the already-started pool must be stopped here or it would
            # outlive the failed test.
            self.pool.stop()
            raise
Exemple #2
0
class Test(unittest.TestCase):
    """Exercise ``MarsAPI`` against a local, single-process actor pool."""

    def setUp(self):
        """Start a one-process gevent actor pool and build a ``MarsAPI`` on it.

        The pool is bound to ``127.0.0.1`` on the port returned by
        ``get_next_port()`` and is populated with the cluster-info,
        session-manager and resource actors, each registered under its
        ``default_name()``.
        """
        endpoint = '127.0.0.1:%d' % get_next_port()
        self.endpoint = endpoint
        self.pool = create_actor_pool(n_process=1, backend='gevent', address=endpoint)
        try:
            self.pool.create_actor(SchedulerClusterInfoActor, [endpoint],
                                   uid=SchedulerClusterInfoActor.default_name())
            self.pool.create_actor(SessionManagerActor, uid=SessionManagerActor.default_name())
            self.pool.create_actor(ResourceActor, uid=ResourceActor.default_name())

            self.api = MarsAPI(endpoint)
        except BaseException:
            # unittest does not run tearDown when setUp raises, so the
            # already-started pool must be stopped here or it would outlive
            # the failed test.
            self.pool.stop()
            raise

    def tearDown(self):
        """Stop the actor pool started in ``setUp``."""
        self.pool.stop()

    @patch_method(GraphActor.execute_graph)
    def testApi(self, *_):
        """Walk a session and a graph through create / query / stop / delete.

        ``GraphActor.execute_graph`` is patched for the duration of the test
        (presumably so the placeholder graph is never really executed --
        confirm against ``patch_method``).
        """
        # No worker has been registered with the pool built in setUp.
        self.assertEqual(0, self.api.count_workers())

        # Session life cycle: create, delete, then create again so the graph
        # calls below have a live session to work with.
        session_id = 'mock_session_id'
        self.api.create_session(session_id)
        self.assertEqual(1, len(self.api.session_manager.get_sessions()))
        self.api.delete_session(session_id)
        self.assertEqual(0, len(self.api.session_manager.get_sessions()))
        self.api.create_session(session_id)

        # Submitting a graph must leave a GraphActor with the derived uid
        # registered in the pool.
        serialized_graph = 'mock_serialized_graph'
        graph_key = 'mock_graph_key'
        targets = 'mock_targets'
        self.api.submit_graph(session_id, serialized_graph, graph_key, targets)
        graph_uid = GraphActor.gen_uid(session_id, graph_key)
        graph_ref = self.api.get_actor_ref(graph_uid)
        self.assertTrue(self.pool.has_actor(graph_ref))

        state = self.api.get_graph_state(session_id, graph_key)
        self.assertEqual(GraphState('preparing'), state)

        self.api.stop_graph(session_id, graph_key)
        state = self.api.get_graph_state(session_id, graph_key)
        self.assertEqual(GraphState('cancelled'), state)

        # Deleting the graph must remove its actor from the pool.
        self.api.delete_graph(session_id, graph_key)
        self.assertFalse(self.pool.has_actor(graph_ref))

    @patch_method(GraphActor.get_tileable_chunk_indexes)
    @patch_method(ChunkMetaClient.batch_get_chunk_shape)
    def testGetTensorNsplits(self, *_):
        """Check ``get_tileable_nsplits`` for a 1-D and a 2-D chunk layout.

        Chunk indexes and chunk shapes are supplied through the
        ``side_effect`` of the two patched methods, one entry per call to
        ``get_tileable_nsplits``.
        """
        session_id = 'mock_session_id'
        graph_key = 'mock_graph_key'
        tensor_key = 'mock_tensor_key'
        serialized_graph = 'mock_serialized_graph'

        graph_uid = GraphActor.gen_uid(session_id, graph_key)
        self.pool.create_actor(GraphActor, session_id, serialized_graph, graph_key, uid=graph_uid)
        self.pool.create_actor(ChunkMetaActor, uid=ChunkMetaActor.default_name())

        # First entry: four chunks along a single axis.
        # Second entry: the same four chunk keys arranged as a 2 x 2 grid.
        mock_indexes = [
            OrderedDict(zip(['chunk_key1', 'chunk_key2', 'chunk_key3', 'chunk_key4'],
                            [(0, ), (1,), (2,), (3,)])),
            OrderedDict(zip(['chunk_key1', 'chunk_key2', 'chunk_key3', 'chunk_key4'],
                            [(0, 0), (0, 1), (1, 0), (1, 1)]))
        ]
        # Chunk shapes listed in the same order as the chunk keys above.
        mock_shapes = [
            [(3,), (4,), (5,), (6,)],
            [(3, 4), (3, 2), (2, 4), (2, 2)]
        ]

        GraphActor.get_tileable_chunk_indexes.side_effect = mock_indexes
        ChunkMetaClient.batch_get_chunk_shape.side_effect = mock_shapes

        # 1-D case: one tuple holding the length of each chunk.
        nsplits = self.api.get_tileable_nsplits(session_id, graph_key, tensor_key)
        self.assertEqual(((3, 4, 5, 6),), nsplits)

        # 2-D case: row heights (3, 2) and column widths (4, 2).
        nsplits = self.api.get_tileable_nsplits(session_id, graph_key, tensor_key)
        self.assertEqual(((3, 2), (4, 2)), nsplits)