Code example #1
File: train_network.py  Project: eric-clinch/GOat
def GetDataset(replay_memory: ReplayMemory, mem_lock: Lock,
               batchsize: int) -> DataLoader:
    datapoints = []
    mem_lock.acquire()
    for move_datapoint in replay_memory.memory:
        player = move_datapoint.player
        player_stone_map = GetPlayerStoneMap(move_datapoint.board_list, player)
        opponent_stone_map = GetPlayerStoneMap(move_datapoint.board_list,
                                               1 - player)
        network_input = torch.Tensor([player_stone_map, opponent_stone_map])
        value = torch.Tensor([move_datapoint.confidence])
        policy = torch.Tensor(move_datapoint.policy)
        datapoints.append((network_input, value, policy))
    mem_lock.release()
    return DataLoader(datapoints, batch_size=batchsize, shuffle=True)
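
A minimal sketch of how the DataLoader returned by GetDataset might be consumed by the trainer; the network, optimizer, and combined value/policy loss shown here are assumptions for illustration, not the project's actual training code.

import torch

def TrainEpoch(network, optimizer, loader):
    # Hypothetical training loop over the (input, value, policy) tuples
    # yielded by the DataLoader; network and optimizer are assumed to exist.
    for network_input, value, policy in loader:
        optimizer.zero_grad()
        pred_value, pred_policy = network(network_input)
        value_loss = torch.nn.functional.mse_loss(pred_value, value)
        # cross-entropy against the soft policy targets
        policy_loss = -(policy * torch.log_softmax(pred_policy, dim=1)).sum(dim=1).mean()
        (value_loss + policy_loss).backward()
        optimizer.step()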
Code example #2
File: train_network.py  Project: eric-clinch/GOat
def ReceivePlayouts(worker: socket.socket, worker_id: int,
                    replay_memory: ReplayMemory, mem_lock: Lock):
    worker.setblocking(True)
    while True:
        try:
            msg = communication.Receive(worker)
        except Exception as err:
            print(f"Error with worker {worker_id}: {err}, ending connection")
            worker.close()
            return

        playout: List[MoveDatapoint] = pickle.loads(msg)

        mem_lock.acquire()
        for move_datapoint in playout:
            replay_memory.push(move_datapoint)
        mem_lock.release()
        print(f"{len(playout)} new datapoints added from worker {worker_id}")
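
A hedged sketch of how ReceivePlayouts could be driven: an accept loop that hands each incoming worker connection to its own thread, all sharing one ReplayMemory and one Lock. The host, port, and thread-per-connection layout are assumptions, not part of the original project.

import socket
import threading

def AcceptWorkers(replay_memory: ReplayMemory, mem_lock: Lock,
                  host: str = "0.0.0.0", port: int = 9999):
    # Hypothetical accept loop; the address and port are placeholders.
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind((host, port))
    server.listen()
    worker_id = 0
    while True:
        worker, _addr = server.accept()
        threading.Thread(target=ReceivePlayouts,
                         args=(worker, worker_id, replay_memory, mem_lock),
                         daemon=True).start()
        worker_id += 1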
Code example #3
def main():
    args = parse_args()
    categories = parse_categories(parse_data(args.data)['names'])

    cap = cv2.VideoCapture(0)
    frame_queue = Queue()
    preds_queue = Queue()
    cur_dets = None
    frame_lock = Lock()

    proc = Process(target=detect,
                   args=(frame_queue, preds_queue, frame_lock, args))
    proc.start()

    try:

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_lock.acquire()
            while not frame_queue.empty():
                frame_queue.get()

            frame_queue.put(frame)
            frame_lock.release()

            if not preds_queue.empty():
                cur_dets = preds_queue.get()

            if cur_dets is not None and len(cur_dets) > 0:
                frame = draw_detections_opencv(frame, cur_dets[0], categories)

            cv2.imshow('frame', frame)
            cv2.waitKey(1)

    except KeyboardInterrupt:
        print('Interrupted')
        proc.join()
        cap.release()
        cv2.destroyAllWindows()
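
The detect target handed to Process above is not shown; a minimal sketch of what such a worker loop might look like, assuming it drains the latest frame from frame_queue and publishes detections to preds_queue. load_model and run_inference are hypothetical placeholders, not the project's real API.

def detect(frame_queue, preds_queue, frame_lock, args):
    # Hypothetical detection worker; load_model/run_inference are placeholders.
    model = load_model(args)
    while True:
        frame_lock.acquire()
        frame = frame_queue.get() if not frame_queue.empty() else None
        frame_lock.release()
        if frame is None:
            continue
        dets = run_inference(model, frame)  # list of detections for this frame
        preds_queue.put(dets)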
Code example #4
class MultimodalPatchesCache(object):
    def __init__(self,
                 cache_dir,
                 dataset_dir,
                 dataset_list,
                 cuda,
                 batch_size=500,
                 num_workers=3,
                 renew_frequency=5,
                 rejection_radius_position=0,
                 numpatches=900,
                 numneg=3,
                 pos_thr=50.0,
                 reject=True,
                 mode='train',
                 rejection_radius=3000,
                 dist_type='3D',
                 patch_radius=None,
                 use_depth=False,
                 use_normals=False,
                 use_silhouettes=False,
                 color_jitter=False,
                 greyscale=False,
                 maxres=4096,
                 scale_jitter=False,
                 photo_jitter=False,
                 uniform_negatives=False,
                 needles=0,
                 render_only=False,
                 maxitems=200,
                 cache_once=False):
        super(MultimodalPatchesCache, self).__init__()
        self.cache_dir = cache_dir
        self.dataset_dir = dataset_dir
        #self.images_path = images_path
        self.dataset_list = dataset_list
        self.cuda = cuda
        self.batch_size = batch_size

        self.num_workers = num_workers
        self.renew_frequency = renew_frequency
        self.rejection_radius_position = rejection_radius_position
        self.numpatches = numpatches
        self.numneg = numneg
        self.pos_thr = pos_thr
        self.reject = reject
        self.mode = mode
        self.rejection_radius = rejection_radius
        self.dist_type = dist_type
        self.patch_radius = patch_radius
        self.use_depth = use_depth
        self.use_normals = use_normals
        self.use_silhouettes = use_silhouettes
        self.color_jitter = color_jitter
        self.greyscale = greyscale
        self.maxres = maxres
        self.scale_jitter = scale_jitter
        self.photo_jitter = photo_jitter
        self.uniform_negatives = uniform_negatives
        self.needles = needles
        self.render_only = render_only

        self.cache_done_lock = Lock()
        self.all_done = Value('B', 0)  # 0 is False
        self.cache_done = Value('B', 0)  # 0 is False

        self.wait_for_cache_builder = Event()
        # prepare to wait until the initial cache is built
        self.wait_for_cache_builder.clear()
        self.cache_builder_resume = Event()

        self.maxitems = maxitems
        self.cache_once = cache_once

        if self.mode == 'eval':
            self.maxitems = -1
        self.cache_builder = Process(target=self.buildCache,
                                     args=[self.maxitems])
        self.current_cache_build = Value('B', 0)  # start building into cache 0
        self.current_cache_use = Value('B', 1)  # start reading from cache 1

        self.cache_names = ["cache1", "cache2"]  # constant

        rebuild_cache = True
        if self.mode == 'eval':
            validation_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(validation_dir):
                # we don't need to rebuild validation cache
                # TODO: check if cache is VALID
                rebuild_cache = False
        elif cache_once:
            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(build_dataset_dir):
                # we don't need to rebuild the training cache if we are
                # training on a limited subset of the training set
                rebuild_cache = False

        if rebuild_cache:
            # clear the caches if they already exist
            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(build_dataset_dir):
                shutil.rmtree(build_dataset_dir)
            use_dataset_dir = os.path.join(
                self.cache_dir, self.cache_names[self.current_cache_use.value])
            if os.path.isdir(use_dataset_dir):
                shutil.rmtree(use_dataset_dir)

            os.makedirs(build_dataset_dir)

            self.cache_builder_resume.set()
            self.cache_builder.start()

            # wait until initial cache is built
            # print("before wait to build")
            # print("wait for cache builder state",
            #       self.wait_for_cache_builder.is_set())
            self.wait_for_cache_builder.wait()
            # print("after wait to build")

        # we have been resumed
        if self.mode != 'eval' and (not self.cache_once):
            # for training, we can set up the cache builder to build
            # the second cache
            self.restart()
        else:
            # else for validation we don't need second cache
            # we just need to switch the built cache to the use cache in order
            # to use it
            tmp = self.current_cache_build.value
            self.current_cache_build.value = self.current_cache_use.value
            self.current_cache_use.value = tmp

        # initialization finished, now this dataset can be used

    def getCurrentCache(self):
        # Lock should not be needed - cache_done is not touched here,
        # and cache_len is read only for the cache in use, which should
        # not be touched by other threads
        # self.cache_done_lock.acquire()
        h5_dataset_filename = os.path.join(
            self.cache_dir, self.cache_names[self.current_cache_use.value])
        # self.cache_done_lock.release()
        return h5_dataset_filename

    def restart(self):
        # print("Restarting - waiting for lock...")
        self.cache_done_lock.acquire()
        # print("Restarting cached dataset...")
        if self.cache_done.value and (not self.cache_once):
            cache_changed = True
            tmp_cache_name = self.current_cache_use.value
            self.current_cache_use.value = self.current_cache_build.value
            self.current_cache_build.value = tmp_cache_name
            # clear the old cache if exists
            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            if os.path.isdir(build_dataset_dir):
                shutil.rmtree(build_dataset_dir)
            os.makedirs(build_dataset_dir)
            self.cache_done.value = 0  # 0 is False
            self.cache_builder_resume.set()
            # print("Switched cache to: ",
            #       self.cache_names[self.current_cache_use.value]
            # )
        else:
            cache_changed = False
            # print(
            #     "New cache not ready, continuing with old cache:",
            #     self.cache_names[self.current_cache_use.value]
            # )
        all_done_value = self.all_done.value
        self.cache_done_lock.release()
        # returns True if no more items are available to be loaded;
        # this object should be destroyed and a new dataset created
        # in order to start over.
        return cache_changed, all_done_value

    def buildCache(self, limit):
        # print("Building cache: ",
        #       self.cache_names[self.current_cache_build.value]
        # )
        dataset = MultimodalPatchesDatasetAll(
            self.dataset_dir,
            self.dataset_list,
            rejection_radius_position=self.rejection_radius_position,
            #self.images_path, list=train_sampled,
            numpatches=self.numpatches,
            numneg=self.numneg,
            pos_thr=self.pos_thr,
            reject=self.reject,
            mode=self.mode,
            rejection_radius=self.rejection_radius,
            dist_type=self.dist_type,
            patch_radius=self.patch_radius,
            use_depth=self.use_depth,
            use_normals=self.use_normals,
            use_silhouettes=self.use_silhouettes,
            color_jitter=self.color_jitter,
            greyscale=self.greyscale,
            maxres=self.maxres,
            scale_jitter=self.scale_jitter,
            photo_jitter=self.photo_jitter,
            uniform_negatives=self.uniform_negatives,
            needles=self.needles,
            render_only=self.render_only)
        n_triplets = len(dataset)

        if limit == -1:
            limit = n_triplets

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=1,  # self.num_workers
            collate_fn=MultimodalPatchesCache.my_collate)

        qmaxsize = 15
        data_queue = JoinableQueue(maxsize=qmaxsize)

        # cannot load to cuda from background, therefore use cpu device
        preloader_resume = Event()
        preloader = Process(target=MultimodalPatchesCache.generateTrainingData,
                            args=(data_queue, dataset, dataloader,
                                  self.batch_size, qmaxsize, preloader_resume,
                                  True, True))
        preloader.do_run_generate = True
        preloader.start()
        preloader_resume.set()

        i_batch = 0
        data = data_queue.get()
        i_batch = data[0]

        counter = 0
        while i_batch != -1:

            self.cache_builder_resume.wait()

            build_dataset_dir = os.path.join(
                self.cache_dir,
                self.cache_names[self.current_cache_build.value])
            batch_fname = os.path.join(build_dataset_dir,
                                       'batch_' + str(counter) + '.pt')

            # print("ibatch", i_batch,
            #        "___data___", data[3].shape, data[6].shape)

            anchor = data[1]
            pos = data[2]
            neg = data[3]
            anchor_r = data[4]
            pos_p = data[5]
            neg_p = data[6]
            c1 = data[7]
            c2 = data[8]
            cneg = data[9]
            id = data[10]

            if not (self.use_depth or self.use_normals):
                # no need to store image data as float, convert to uint8
                anchor = (anchor * 255.0).to(torch.uint8)
                pos = (pos * 255.0).to(torch.uint8)
                neg = (neg * 255.0).to(torch.uint8)
                anchor_r = (anchor_r * 255.0).to(torch.uint8)
                pos_p = (pos_p * 255.0).to(torch.uint8)
                neg_p = (neg_p * 255.0).to(torch.uint8)

            tosave = {
                'anchor': anchor,
                'pos': pos,
                'neg': neg,
                'anchor_r': anchor_r,
                'pos_p': pos_p,
                'neg_p': neg_p,
                'c1': c1,
                'c2': c2,
                'cneg': cneg,
                'id': id
            }

            try:
                torch.save(tosave, batch_fname)
                torch.load(batch_fname)
                counter += 1
            except Exception as e:
                print("Could not save ",
                      batch_fname,
                      ", due to:",
                      e,
                      "skipping...",
                      file=sys.stderr)
                if os.path.isfile(batch_fname):
                    os.remove(batch_fname)

            data_queue.task_done()

            if counter >= limit:
                self.cache_done_lock.acquire()
                self.cache_done.value = 1  # 1 is True
                self.cache_done_lock.release()
                counter = 0
                # sleep until calling thread wakes us
                self.cache_builder_resume.clear()
                # resume calling thread so that it can work
                self.wait_for_cache_builder.set()

            data = data_queue.get()
            i_batch = data[0]
            #print("ibatch", i_batch)

        data_queue.task_done()

        self.cache_done_lock.acquire()
        self.cache_done.value = 1  # 1 is True
        self.all_done.value = 1
        print("Cache done ALL")
        self.cache_done_lock.release()
        # resume calling thread so that it can work
        self.wait_for_cache_builder.set()
        preloader.join()
        preloader = None
        data_queue = None

    @staticmethod
    def loadBatch(sample_batched, mode, device, keep_all=False):
        if mode == 'eval':
            coords1 = sample_batched[6]
            coords2 = sample_batched[7]
            coords_neg = sample_batched[8]
            keep = sample_batched[10]
            item_id = sample_batched[11]
        else:
            coords1 = sample_batched[6]
            coords2 = sample_batched[7]
            coords_neg = sample_batched[8]
            keep = sample_batched[9]
            item_id = sample_batched[10]
        if keep_all:
            # requested to return the full batch
            batchsize = sample_batched[0].shape[0]
            keep = torch.ones(batchsize).byte()
        keep = keep.reshape(-1)
        keep = keep.bool()
        anchor = sample_batched[0]
        pos = sample_batched[1]
        neg = sample_batched[2]

        # swapped photo to render
        anchor_r = sample_batched[3]
        pos_p = sample_batched[4]
        neg_p = sample_batched[5]

        anchor = anchor[keep].to(device)
        pos = pos[keep].to(device)
        neg = neg[keep].to(device)

        anchor_r = anchor_r[keep]
        pos_p = pos_p[keep]
        neg_p = neg_p[keep]

        coords1 = coords1[keep]
        coords2 = coords2[keep]
        coords_neg = coords_neg[keep]
        item_id = item_id[keep]
        return anchor, pos, neg, anchor_r, pos_p, neg_p, coords1, coords2, \
            coords_neg, item_id

    @staticmethod
    def generateTrainingData(queue,
                             dataset,
                             dataloader,
                             batch_size,
                             qmaxsize,
                             resume,
                             shuffle=True,
                             disable_tqdm=False):
        local_buffer_a = []
        local_buffer_p = []
        local_buffer_n = []

        local_buffer_ar = []
        local_buffer_pp = []
        local_buffer_np = []

        local_buffer_c1 = []
        local_buffer_c2 = []
        local_buffer_cneg = []
        local_buffer_id = []
        nbatches = 10
        # cannot load to cuda in background process!
        device = torch.device('cpu')

        buffer_size = min(qmaxsize * batch_size, nbatches * batch_size)
        bidx = 0
        for i_batch, sample_batched in enumerate(dataloader):
            # tqdm(dataloader, disable=disable_tqdm)
            resume.wait()
            anchor, pos, neg, anchor_r, \
                pos_p, neg_p, c1, c2, cneg, id = \
                MultimodalPatchesCache.loadBatch(
                    sample_batched, dataset.mode, device
                )
            if anchor.shape[0] == 0:
                continue
            local_buffer_a.extend(list(anchor))  # [:current_batches]
            local_buffer_p.extend(list(pos))
            local_buffer_n.extend(list(neg))

            local_buffer_ar.extend(list(anchor_r))
            local_buffer_pp.extend(list(pos_p))
            local_buffer_np.extend(list(neg_p))

            local_buffer_c1.extend(list(c1))
            local_buffer_c2.extend(list(c2))
            local_buffer_cneg.extend(list(cneg))
            local_buffer_id.extend(list(id))
            if len(local_buffer_a) >= buffer_size:
                if shuffle:
                    local_buffer_a, local_buffer_p, local_buffer_n, \
                        local_buffer_ar, local_buffer_pp, local_buffer_np, \
                        local_buffer_c1, local_buffer_c2, local_buffer_cneg, \
                        local_buffer_id = sklearn.utils.shuffle(
                            local_buffer_a,
                            local_buffer_p,
                            local_buffer_n,
                            local_buffer_ar,
                            local_buffer_pp,
                            local_buffer_np,
                            local_buffer_c1,
                            local_buffer_c2,
                            local_buffer_cneg,
                            local_buffer_id
                        )
                curr_nbatches = int(np.floor(len(local_buffer_a) / batch_size))
                for i in range(0, curr_nbatches):
                    queue.put([
                        bidx,
                        torch.stack(local_buffer_a[:batch_size]),
                        torch.stack(local_buffer_p[:batch_size]),
                        torch.stack(local_buffer_n[:batch_size]),
                        torch.stack(local_buffer_ar[:batch_size]),
                        torch.stack(local_buffer_pp[:batch_size]),
                        torch.stack(local_buffer_np[:batch_size]),
                        torch.stack(local_buffer_c1[:batch_size]),
                        torch.stack(local_buffer_c2[:batch_size]),
                        torch.stack(local_buffer_cneg[:batch_size]),
                        torch.stack(local_buffer_id[:batch_size])
                    ])
                    del local_buffer_a[:batch_size]
                    del local_buffer_p[:batch_size]
                    del local_buffer_n[:batch_size]
                    del local_buffer_ar[:batch_size]
                    del local_buffer_pp[:batch_size]
                    del local_buffer_np[:batch_size]
                    del local_buffer_c1[:batch_size]
                    del local_buffer_c2[:batch_size]
                    del local_buffer_cneg[:batch_size]
                    del local_buffer_id[:batch_size]
                    bidx += 1
        remaining_batches = len(local_buffer_a) // batch_size
        for i in range(0, remaining_batches):
            queue.put([
                bidx,
                torch.stack(local_buffer_a[:batch_size]),
                torch.stack(local_buffer_p[:batch_size]),
                torch.stack(local_buffer_n[:batch_size]),
                torch.stack(local_buffer_ar[:batch_size]),
                torch.stack(local_buffer_pp[:batch_size]),
                torch.stack(local_buffer_np[:batch_size]),
                torch.stack(local_buffer_c1[:batch_size]),
                torch.stack(local_buffer_c2[:batch_size]),
                torch.stack(local_buffer_cneg[:batch_size]),
                torch.stack(local_buffer_id[:batch_size])
            ])
            del local_buffer_a[:batch_size]
            del local_buffer_p[:batch_size]
            del local_buffer_n[:batch_size]
            del local_buffer_ar[:batch_size]
            del local_buffer_pp[:batch_size]
            del local_buffer_np[:batch_size]
            del local_buffer_c1[:batch_size]
            del local_buffer_c2[:batch_size]
            del local_buffer_cneg[:batch_size]
            del local_buffer_id[:batch_size]
        ra = torch.randn(batch_size, 3, 64, 64)
        queue.put([-1, ra, ra, ra])
        queue.join()

    @staticmethod
    def my_collate(batch):
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
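
A hedged sketch of how MultimodalPatchesCache might be driven from a training loop: read the batch_*.pt files from the directory returned by getCurrentCache, then call restart() to swap to the freshly built cache. The constructor arguments and the consumption loop are assumptions, not the project's actual trainer.

import glob
import os
import torch

# Hypothetical driver; cache/dataset paths and arguments are placeholders.
cache = MultimodalPatchesCache("cache", "datasets", "train_list.txt", cuda=False)
all_done = 0
while not all_done:
    cache_path = cache.getCurrentCache()
    for batch_fname in sorted(glob.glob(os.path.join(cache_path, "batch_*.pt"))):
        batch = torch.load(batch_fname)
        anchor, pos, neg = batch['anchor'], batch['pos'], batch['neg']
        # ... feed the triplet (and the swapped photo/render views) to the model ...
    cache_changed, all_done = cache.restart()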
Code example #5
class lazy_array_loader(object):
    """
    Arguments:
        path: path to directory where array entries are concatenated into one big string file
            and the .len files are located
        data_type (str): Some datasets have multiple fields that are stored in different paths.
            `data_type` specifies which of these fields to load in this class
        mem_map (boolean): Specifies whether to memory map file `path`
        map_fn (callable): Fetched strings are passed through map_fn before being returned.

    Example of lazy loader directory structure:
    file.json
    file.lazy/
        data_type1
        data_type1.len.pkl
        data_type2
        data_type2.len.pkl
    """
    def __init__(self, path, data_type='data', mem_map=False, map_fn=None):
        lazypath = get_lazy_path(path)
        datapath = os.path.join(lazypath, data_type)
        #get file where array entries are concatenated into one big string
        self._file = open(datapath, 'rb')
        self.file = self._file
        #memory map file if necessary
        self.mem_map = mem_map
        if self.mem_map:
            self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ)
        lenpath = os.path.join(lazypath, data_type+'.len.pkl')
        self.lens = pkl.load(open(lenpath, 'rb'))
        self.ends = list(accumulate(self.lens))
        self.dumb_ends = list(self.ends)
        self.read_lock = Lock()
        self.process_fn = map_fn
        self.map_fn = map_fn
        self._tokenizer = None

    def SetTokenizer(self, tokenizer):
        """
        Logic to set and remove (set to None) the tokenizer;
        combines preprocessing and tokenization into one callable.
        """
        if tokenizer is None:
            if not hasattr(self, '_tokenizer'):
                self._tokenizer = tokenizer
        else:
            self._tokenizer = tokenizer
        self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn)

    def GetTokenizer(self):
        return self._tokenizer

    def __getitem__(self, index):
        """
        read file and splice strings based on string ending array `self.ends`
        """
        if not isinstance(index, slice):
            if index == 0:
                start = 0
            else:
                start = self.ends[index-1]
            end = self.ends[index]
            rtn = self.file_read(start, end)
            if self.map_fn is not None:
                return self.map_fn(rtn)
        else:
            # if slice, fetch strings with 1 diskread and then splice in memory
            chr_lens = self.ends[index]
            if index.start == 0 or index.start is None:
                start = 0
            else:
                start = self.ends[index.start-1]
            stop = chr_lens[-1]
            strings = self.file_read(start, stop)
            rtn = split_strings(strings, start, chr_lens)
            if self.map_fn is not None:
                return self.map_fn([s for s in rtn])
        return rtn

    def __len__(self):
        return len(self.ends)

    def file_read(self, start=0, end=None):
        """read specified portion of file"""

        # atomic reads to avoid race conditions with multiprocess dataloader
        self.read_lock.acquire()
        # seek to start of file read
        self.file.seek(start)
        # read to end of file if no end point provided
        if end is None:
            rtn = self.file.read()
        #else read amount needed to reach end point
        else:
            rtn = self.file.read(end-start)
        self.read_lock.release()
        #TODO: @raulp figure out mem map byte string bug
        #if mem map'd need to decode byte string to string
        rtn = rtn.decode('utf-8', 'ignore')
        # rtn = str(rtn)
        if self.mem_map:
            rtn = rtn.decode('unicode_escape')
        return rtn
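
A hedged sketch of how lazy_array_loader might be used; the 'file.json' path and its pre-built 'file.lazy/' directory follow the layout described in the docstring and are assumptions here.

# Hypothetical usage; assumes file.lazy/data and file.lazy/data.len.pkl already exist.
loader = lazy_array_loader('file.json', data_type='data', mem_map=False)
print(len(loader))    # number of entries, taken from the .len.pkl file
print(loader[0])      # single entry: one seek + read under the read_lock
print(loader[2:5])    # slice: one disk read, then spliced in memory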
Code example #6
class Player(QWidget):
    def __init__(self, external_mutex, parent=None):
        super(Player, self).__init__(parent)
        self.init_UI()
        self.init_data()
        self.init_process()
        self.init_connect()
        self.init_offline()
        self.external_mutex = external_mutex

    def init_UI(self):
        self.layout1 = QtWidgets.QGridLayout()  # grid layout for the main widget
        self.label_screen = QtWidgets.QLabel(self)  # label used to display frames
        self.btn_start = QtWidgets.QToolButton()
        self.btn_pause = QtWidgets.QToolButton()
        self.btn_close = QtWidgets.QToolButton()
        self.btn_start.setIcon(QIcon('GUI/resources/icons/play.png'))
        self.btn_start.setIconSize(QSize(30, 30))
        self.btn_pause.setIcon(QIcon('GUI/resources/icons/pause.png'))
        self.btn_pause.setIconSize(QSize(30, 30))
        self.btn_close.setIcon(QIcon('GUI/resources/icons/close.png'))
        self.btn_close.setIconSize(QSize(30, 30))

        self.setStyleSheet('''
            QToolButton{border:none;color:red;}
            QToolButton#left_label{
                border:none;
                border-bottom:1px solid white;
                font-size:18px;
                font-weight:700;
                font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
            }
            QToolButton#left_button:hover{border-left:4px solid red;font-weight:700;}
        ''')

        self.label_info = QtWidgets.QLabel()
        self.label_info.setText("[空闲]")  # "[idle]"
        pe = QPalette()
        pe.setColor(QPalette.WindowText, Qt.white)
        self.label_info.setPalette(pe)
        self.label_info.setFont(QFont("Microsoft YaHei", 10, QFont.Bold))
        self.progressBar = QtWidgets.QSlider()
        self.progressBar.setOrientation(QtCore.Qt.Horizontal)
        self.layout1.addWidget(self.label_info, 0, 0, 1, 7)
        self.layout1.addWidget(self.btn_close, 0, 7, 1, 1)
        self.layout1.addWidget(self.label_screen, 1, 0, 7, 8)
        self.layout1.addWidget(self.btn_start, 8, 0, 1, 1)
        self.layout1.addWidget(self.btn_pause, 8, 1, 1, 1)
        self.layout1.addWidget(self.progressBar, 8, 2, 1, 6)
        self.bottom_widgets = [
            self.btn_start, self.btn_pause, self.progressBar
        ]
        self.label_screen.setScaledContents(True)
        self.label_screen.setPixmap(QPixmap("GUI/resources/helmet.jpg"))
        self.setLayout(self.layout1)  # use the grid layout as the window's main layout

    def init_data(self):
        self.is_working = False
        self.semaphore = True
        self.is_change_bar = Value(
            c_bool, False)  # whether the user has dragged the slider; default False

        self.frame_index = Value('i', 0)
        self.share_lock = Lock()  #shared lock for frame_index
        self.share_lock2 = Lock()  # shared lock for frame_index

        self.mutex = threading.Lock()

        self.timer = QTimer(self)  # used to update the progress bar
        self.temp_timer = QTimer(
            self)  # used to detect whether frame_total has been set
        self.frame_total = Value('i', -1)
        self.playable = Value(c_bool, True)
        self.is_working = Value(c_bool, False)
        manager = Manager()
        self.play_src = manager.Value(c_char_p, '0')  # records the source path of the video being played
        self.mode = None  # 'online' or 'offline'

    def init_connect(self):
        self.btn_pause.clicked.connect(self.pause)
        self.btn_close.clicked.connect(self.close)
        self.btn_start.clicked.connect(self.play)

    def init_offline_connect(self):

        self.timer.timeout.connect(self.update_progressBar)
        self.timer.start(50)  # update the progressbar value every 50ms
        self.temp_timer.timeout.connect(self.set_MaxValue)
        self.temp_timer.start(50)

        #progressbar
        self.progressBar.sliderPressed.connect(
            self.lockBar
        )  # when the user is dragging the slider, stop updating the value
        self.progressBar.sliderReleased.connect(self.change_progressBar)

    def init_online_connect(self):
        pass

    def init_process(self):
        self.origin_img_q = mp.Queue(maxsize=2)
        self.result_img_q = mp.Queue(maxsize=4)
        self.p_detector = Process(target=detector,
                                  args=(self.origin_img_q, self.result_img_q))
        self.p_detector.start()
        self.img_fetcher = Process(target=play,
                                   args=(self.origin_img_q, self.frame_index,
                                         self.share_lock, self.frame_total,
                                         self.is_change_bar, self.playable,
                                         self.is_working, self.play_src))
        self.img_fetcher.start()

    def lockBar(self):
        self.share_lock.acquire()
        self.mutex.acquire()
        self.semaphore = False
        self.mutex.release()

    def change_progressBar(self):

        self.frame_index.value = self.progressBar.value()
        self.is_change_bar.value = True
        self.share_lock.release()
        self.mutex.acquire()
        self.semaphore = True
        self.mutex.release()

    def set_MaxValue(self):

        if self.frame_total.value != -1:  # executed only once
            self.progressBar.setMaximum(self.frame_total.value)
            self.temp_timer.disconnect()

    def update_progressBar(self):
        self.mutex.acquire()
        if self.semaphore:
            self.progressBar.setValue(self.frame_index.value)
        self.mutex.release()

    def close(self):
        self.is_working.value = False
        self.label_screen.setPixmap(QPixmap("GUI/resources/helmet.jpg"))
        time.sleep(0.1)
        self.label_info.setText('空闲')  # "idle"

    def play(self):
        self.playable.value = True
        temp_str = self.label_info.text()
        if '[暂停]' in temp_str:  # strip the "[paused]" tag
            temp_str = temp_str.replace('[暂停]', '')
        self.label_info.setText(temp_str)

    def pause(self):
        self.playable.value = False
        temp_str = self.label_info.text()
        if '[暂停]' not in temp_str:  # prepend the "[paused]" tag
            temp_str = '[暂停]' + temp_str
        self.label_info.setText(temp_str)

    def display(self):
        while True:
            start_time = time.time()

            if self.is_working.value:
                self.external_mutex.acquire()
                if not self.result_img_q.empty():
                    prev = time.time()
                    show = self.result_img_q.get()
                    post = time.time()
                    # print(datetime.timedelta(seconds=post - prev))

                    showImage = QImage(show.data, show.shape[1], show.shape[0],
                                       QImage.Format_RGB888)  # convert to QImage
                    self.label_screen.setScaledContents(True)

                    self.label_screen.setPixmap(
                        QPixmap.fromImage(showImage))  #
                    print("FPS: ",
                          1.0 / (time.time() -
                                 start_time))  # FPS = 1 / time to process loop

                self.external_mutex.release()

            else:
                time.sleep(0.1)

    def init_offline(self, video_path=None):
        """

        :param video_path: path of the video to play offline
        :return:
        """

        self.init_offline_connect()

        self.offline_video_thread = threading.Thread(target=self.display)
        self.offline_video_thread.start()

    def start_online(self, play_src):
        self.play_src.value = play_src  # path of the online camera stream
        play_src = '0'
        self.label_info.setText('[在线模式]:' + play_src)  # "[online mode]"
        self.is_working.value = True
        self.playable.value = True
        # for i in self.bottom_widgets:  # delete all the left_widgets and then replace them with new ones
        #     i.setVisible(False)
        #     self.layout1.removeWidget(i)
        # self.layout1.addWidget(self.btn_start, 8, 2, 1, 1)
        # self.layout1.addWidget(self.btn_pause, 8, 4, 1, 1)
        # self.btn_start.setVisible(True)
        # self.btn_pause.setVisible(True)
        self.progressBar.setVisible(False)

    def restart_online(self, play_src):
        self.play_src.value = play_src
        play_src = '0'
        self.is_working.value = False
        self.playable.value = False
        self.label_info.setText('[在线模式]:' + play_src)
        self.progressBar.setVisible(False)
        time.sleep(0.1)
        self.is_working.value = True
        self.playable.value = True

    def start_offline(self, play_src):
        self.play_src.value = play_src  # path of the offline video file
        self.label_info.setText('[离线模式]:' + play_src)  # "[offline mode]"
        self.is_working.value = True
        self.playable.value = True
        self.progressBar.setVisible(True)

    def restart_offline(self, play_src):
        self.play_src.value = play_src
        self.is_working.value = False
        self.playable.value = False
        self.label_info.setText('[离线模式]:' + play_src)
        self.progressBar.setVisible(True)
        time.sleep(0.1)
        self.is_working.value = True
        self.playable.value = True
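
A hedged sketch of how Player might be embedded in an application. PyQt5 is assumed (the snippet does not show its imports), and the detector and play worker targets started in init_process must be importable for this to actually run.

import sys
from multiprocessing import Lock
from PyQt5.QtWidgets import QApplication

if __name__ == '__main__':
    # Hypothetical entry point; the external mutex guards access to result_img_q in display().
    app = QApplication(sys.argv)
    player = Player(external_mutex=Lock())
    player.show()
    player.start_offline('GUI/resources/sample.mp4')  # hypothetical video path
    sys.exit(app.exec_())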