Exemplo n.º 1
0
    def testDistributedContext(self):
        """Execute a named tensor, then read its metadata and data back.

        A 10x10 random tensor split into chunks of 4 is executed in a fresh
        session under the name ``'test'``. A ``DistributedContext`` bound to
        the same scheduler address and session id is then used to:

        * resolve the name to the tensor's key and shape,
        * fetch the ``nsplits`` meta for the tensor,
        * fetch the whole tensor and several indexed selections of it,
          comparing each against the equivalent NumPy indexing of the source.
        """
        self.start_processes(etcd=False)
        scheduler_addr = self.session_manager_ref.address
        session = new_session(scheduler_addr)
        random_state = np.random.RandomState(0)
        ctx = DistributedContext(scheduler_address=scheduler_addr,
                                 session_id=session.session_id)

        data = random_state.rand(10, 10)
        tensor = mt.tensor(data, chunk_size=4)
        tensor.execute(session=session, timeout=self.timeout, name='test')

        # The name given at execution time must resolve to this tensor.
        info = ctx.get_named_tileable_infos('test')
        self.assertEqual(tensor.key, info.tileable_key)
        self.assertEqual(tensor.shape, info.tileable_shape)

        # 10 elements split with chunk_size=4 -> chunks of 4, 4, 2 per axis.
        metas = ctx.get_tileable_metas([tensor.key],
                                       filter_fields=['nsplits'])
        self.assertEqual(((4, 4, 2), (4, 4, 2)), metas[0][0])

        # Fetching without an index returns the complete array.
        whole = ctx.get_tileable_data(tensor.key)
        np.testing.assert_array_equal(data, whole)

        index_cases = (
            # plain slices on both axes
            [slice(3, 9), slice(0, 7)],
            # fancy row selection (with a repeated row), every column
            [[1, 4, 2, 4, 5], slice(None, None, None)],
            # pointwise fancy indexing: pairs (9,0), (1,0), (2,4), (0,4)
            ([9, 1, 2, 0], [0, 0, 4, 4]),
        )
        for index in index_cases:
            fetched = ctx.get_tileable_data(tensor.key, index)
            np.testing.assert_array_equal(data[tuple(index)], fetched)
Exemplo n.º 2
0
    def testDistributedContext(self):
        """Submit a named tensor graph by hand, then read it back.

        A session is created directly on the session manager actor. A
        10x10 random tensor (chunk_size=4) is then submitted as a tileable
        graph with the target name ``'test'``. Once the graph reports
        ``GraphState.SUCCEEDED``, a ``DistributedContext`` pointed at the
        first scheduler endpoint is used to:

        * resolve the name to the tensor's key and shape,
        * fetch the ``nsplits`` meta for the tensor,
        * fetch the whole tensor and several indexed selections of it,
          comparing each against the equivalent NumPy indexing of the source.
        """
        self.start_processes(etcd=False)

        sid = uuid.uuid1()
        client = new_client()
        random_state = np.random.RandomState(0)

        ctx = DistributedContext(
            scheduler_address=self.scheduler_endpoints[0],
            session_id=sid)

        session_ref = client.actor_ref(
            self.session_manager_ref.create_session(sid))
        data = random_state.rand(10, 10)
        tensor = mt.tensor(data, chunk_size=4)

        tileable_graph = tensor.build_graph()
        target_keys = [tensor.key]
        graph_key = uuid.uuid1()
        serialized_graph = json.dumps(tileable_graph.to_json())
        session_ref.submit_tileable_graph(serialized_graph,
                                          graph_key,
                                          target_tileables=target_keys,
                                          names=['test'])

        final_state = self.wait_for_termination(client, session_ref,
                                                graph_key)
        self.assertEqual(final_state, GraphState.SUCCEEDED)

        # The name passed at submission time must resolve to this tensor.
        info = ctx.get_named_tileable_infos('test')
        self.assertEqual(tensor.key, info.tileable_key)
        self.assertEqual(tensor.shape, info.tileable_shape)

        # 10 elements split with chunk_size=4 -> chunks of 4, 4, 2 per axis.
        metas = ctx.get_tileable_metas([tensor.key],
                                       filter_fields=['nsplits'])
        self.assertEqual(((4, 4, 2), (4, 4, 2)), metas[0][0])

        # Fetching without an index returns the complete array.
        whole = ctx.get_tileable_data(tensor.key)
        np.testing.assert_array_equal(data, whole)

        index_cases = (
            # plain slices on both axes
            [slice(3, 9), slice(0, 7)],
            # fancy row selection (with a repeated row), every column
            [[1, 4, 2, 4, 5], slice(None, None, None)],
            # pointwise fancy indexing: pairs (9,0), (1,0), (2,4), (0,4)
            ([9, 1, 2, 0], [0, 0, 4, 4]),
        )
        for index in index_cases:
            fetched = ctx.get_tileable_data(tensor.key, index)
            np.testing.assert_array_equal(data[tuple(index)], fetched)