コード例 #1
0
ファイル: dataset.py プロジェクト: jwzxgy2007/jittor
    def __init__(self, target, args, buffer_size, keep_numpy_array=False):
        """Start a daemon subprocess running ``target`` with a shared ring buffer.

        Args:
            target: callable executed in the child process.
            args: tuple of positional arguments for ``target``. The ring
                buffer and the shared status array are appended to it, so the
                child receives ``(*args, buffer, status)``.
            buffer_size: value passed to ``jt.RingBuffer`` (presumably its
                capacity in bytes -- confirm against the Jittor API).
            keep_numpy_array: when true, ``self.buffer.keep_numpy_array(True)``
                is called. This assumes the Jittor RingBuffer exposes that
                method -- confirm for your Jittor version. The default
                ``False`` leaves the buffer exactly as before this parameter
                existed, so existing callers are unaffected. It is accepted
                because worker-construction code passes ``keep_numpy_array=``
                as a keyword.
        """
        self.buffer = jt.RingBuffer(buffer_size)
        if keep_numpy_array:
            # Only touch the buffer when explicitly requested, so the default
            # path performs no extra call.
            self.buffer.keep_numpy_array(keep_numpy_array)

        # Five unlocked shared floats that the child can write and the parent
        # can read. Their per-slot meaning is defined by ``target``
        # (not visible here).
        self.status = mp.Array('f', 5, lock=False)
        self.p = mp.Process(target=target, args=args + (self.buffer, self.status))
        # Daemon so the child is terminated when the parent process exits.
        self.p.daemon = True
        self.p.start()
コード例 #2
0
 def _init_workers(self, index_list):
     """Create the process-shared loader state and spawn ``self.num_workers`` workers.

     ``jt.clean()`` and ``jt.gc()`` are called first (presumably to release
     Jittor-side memory before child processes start -- confirm). Then the
     shared index array, the id queue, the global token index with its
     conditions and the idle-worker counter are allocated. ``index_list`` is
     copied into shared memory, and one ``Worker`` per worker id is built
     into ``self.workers``.

     Args:
         index_list: sequence of ``self.real_len`` sample indices to copy into
             the shared int32 array.
     """
     jt.clean()
     jt.gc()

     # Shared, unlocked C-int storage holding the sample order.
     self.index_list = mp.Array('i', self.real_len, lock=False)

     # Queue from which worker ids are taken, guarded by a dedicated lock.
     self.idqueue = jt.RingBuffer(2048)
     self.idqueue_lock = mp.Lock()

     # Global token index: allocated from batch_len, then reset to 0.
     self.gid = mp.Value('i', self.batch_len)
     self.gid.value = 0

     # Both conditions below share gid's lock. num_idle is created with
     # lock=False, presumably relying on that same lock -- confirm at call sites.
     gid_lock = self.gid.get_lock()
     self.gidc = mp.Condition(gid_lock)
     self.num_idle = mp.Value('i', 0, lock=False)
     self.num_idle_c = mp.Condition(gid_lock)

     # int32 numpy view over the shared array (assumes C int is 32 bits),
     # then fill it in place with the requested order.
     self.index_list_numpy = np.ndarray(
         dtype='int32',
         shape=self.real_len,
         buffer=self.index_list,
     )
     self.index_list_numpy[:] = index_list

     self.workers = [
         Worker(
             target=self._worker_main,
             args=(worker_id, ),
             buffer_size=self.buffer_size,
             keep_numpy_array=self.keep_numpy_array,
         )
         for worker_id in range(self.num_workers)
     ]
コード例 #3
0
def test_ring_buffer():
    """Round-trip values of many kinds through a ``jt.RingBuffer(2000)``.

    Every value is pushed and immediately popped, and the popped value must
    equal the pushed one. For the scalar, string, container and numpy cases,
    the expected serialized size is accumulated and checked against both
    ``buffer.total_pop()`` and ``buffer.total_push()``. A final push of a
    10x1000 float64 array is expected to raise.
    """
    buffer = jt.RingBuffer(2000)

    def test_send_recv(data):
        print("test send recv", type(data))
        buffer.push(data)
        recv = buffer.pop()
        if isinstance(data, (np.ndarray, jt.Var)):
            # `==` broadcasts, so also pin the shape; otherwise a wrongly
            # shaped result could still compare all-equal.
            assert tuple(recv.shape) == tuple(data.shape)
            assert (recv == data).all()
        else:
            assert data == recv

    def check_byte_counters(expected):
        # Check each counter separately so a failure names the one that is off.
        popped = buffer.total_pop()
        assert popped == expected, f"total_pop() == {popped}, expected {expected}"
        pushed = buffer.total_push()
        assert pushed == expected, f"total_push() == {pushed}, expected {expected}"

    # (value, bytes it is expected to add to both counters). The sizes are the
    # ones this test has always asserted. Reading them as "1-byte tag + 8-byte
    # payload/length + data" is an inference from the numbers only -- confirm
    # against the RingBuffer serializer.
    sized_cases = [
        (1, 1 + 8),
        (100000000000, 1 + 8),
        (1e-5, 1 + 8),
        (100000000000.0, 1 + 8),
        ("float32", 1 + 8 + 7),
        ("", 1 + 8 + 0),
        ("xxxxxxxxxx", 1 + 8 + 10),
        ([1, 0.2], 1 + 8 + 1 + 8 + 1 + 8),
        ({'asd': 1}, 1 + 8 + 1 + 8 + 3 + 1 + 8),
        (np.random.rand(10, 10), 1 + 16 + 2 + 10 * 10 * 8),
    ]
    n_byte = 0
    for value, size in sized_cases:
        test_send_recv(value)
        n_byte += size
        check_byte_counters(n_byte)

    # The remaining cases check only the round trip, not the byte counters.
    test_send_recv(test_ring_buffer)
    test_send_recv(jt.array(np.random.rand(10, 10)))

    bbox = BBox(jt.array(np.random.rand(10, 10)))
    test_send_recv(bbox)

    # 10 * 1000 float64 values = 80000 data bytes, far more than the 2000
    # passed to RingBuffer above (presumably its capacity), so this must fail.
    expect_error(lambda: test_send_recv(np.random.rand(10, 1000)))