# Example 1
    pose_model.to(args.device)
    pose_model.eval()

    # Init data writer
    queueSize = 2 if mode == 'webcam' else args.qsize
    writer = DataWriter(cfg, args, save_video=False,
                        queueSize=queueSize).start()
    data_len = det_loader.length
    im_names_desc = tqdm(range(data_len), dynamic_ncols=True)

    batchSize = args.posebatch
    try:
        for i in im_names_desc:
            with torch.no_grad():
                (inps, orig_img, im_name, boxes, scores, ids,
                 cropped_boxes) = det_loader.read()
                if orig_img is None:
                    break
                if boxes is None or boxes.nelement() == 0:
                    writer.save(None, None, None, None, None, orig_img,
                                os.path.basename(im_name))
                    continue
                # Pose Estimation
                inps = inps.to(args.device)
                datalen = inps.size(0)
                leftover = 0
                if (datalen) % batchSize:
                    leftover = 1
                num_batches = datalen // batchSize + leftover
                hm = []
                for j in range(num_batches):
# Example 2
    def run(self):
        """Run the full pose-estimation pipeline on ``self.video``.

        Loads the person detector and SPPE pose model, streams detections
        through the pose model in mini-batches of ``self.posebatch``,
        optionally tracks poses across frames, and pushes every frame's
        results to a background DataWriter (optionally rendering a video).
        On completion stores the results in ``self.all_results`` and calls
        ``self._save()``.

        Raises:
            IOError: if ``self.video`` does not point to an existing file.
        """
        if os.path.isfile(self.video):
            mode, input_source = 'video', self.video
        else:
            raise IOError(
                'Error: --video must refer to a video file, not directory.')

        if not os.path.exists(self.outputpath):
            os.makedirs(self.outputpath)

        # Detection runs in background workers feeding a bounded queue.
        det_loader = DetectionLoader(input_source,
                                     get_detector(self),
                                     self.cfg,
                                     self,
                                     batchSize=self.detbatch,
                                     mode=mode,
                                     queueSize=self.qsize)
        det_worker = det_loader.start()

        # Build the single-person pose estimator and load its checkpoint.
        pose_model = builder.build_sppe(self.cfg.MODEL,
                                        preset_cfg=self.cfg.DATA_PRESET)

        print(f'Loading pose model from {self.checkpoint}...')
        pose_model.load_state_dict(
            torch.load(self.checkpoint, map_location=self.device))

        if self.pose_track:
            tracker = Tracker(tcfg, self)

        pose_model.to(self.device)
        pose_model.eval()

        if self.save_video:
            from alphapose.utils.writer import DEFAULT_VIDEO_SAVE_OPT as video_save_opt
            # BUGFIX: use os.path.join instead of raw string concatenation so
            # the save path is correct even when ``self.outputpath`` has no
            # trailing separator ('out' + 'v.mp4' silently became 'outv.mp4').
            video_save_opt['savepath'] = os.path.join(
                self.outputpath, os.path.basename(self.video))
            video_save_opt.update(det_loader.videoinfo)
            writer = DataWriter(self.cfg,
                                self,
                                save_video=True,
                                video_save_opt=video_save_opt,
                                queueSize=self.qsize).start()
        else:
            writer = DataWriter(self.cfg,
                                self,
                                save_video=False,
                                queueSize=self.qsize).start()

        data_len = det_loader.length
        im_names_desc = tqdm(range(data_len), dynamic_ncols=True)

        batchSize = self.posebatch

        try:
            for i in im_names_desc:
                with torch.no_grad():
                    (inps, orig_img, im_name, boxes, scores, ids,
                     cropped_boxes) = det_loader.read()
                    # A None image signals the end of the stream.
                    if orig_img is None:
                        break
                    # No detections: still emit the frame so output stays
                    # aligned with the input video.
                    if boxes is None or boxes.nelement() == 0:
                        writer.save(None, None, None, None, None, orig_img,
                                    os.path.basename(im_name))
                        continue

                    # Pose estimation in mini-batches to bound device memory.
                    inps = inps.to(self.device)
                    datalen = inps.size(0)
                    # Ceil-division: one extra batch for any remainder.
                    num_batches = (datalen + batchSize - 1) // batchSize

                    hm = []
                    for j in range(num_batches):
                        inps_j = inps[j * batchSize:min((j + 1) *
                                                        batchSize, datalen)]
                        hm_j = pose_model(inps_j)
                        hm.append(hm_j)

                    hm = torch.cat(hm)
                    if self.pose_track:
                        boxes, scores, ids, hm, cropped_boxes = track(
                            tracker, self, orig_img, inps, boxes, hm,
                            cropped_boxes, im_name, scores)
                    writer.save(boxes, scores, ids, hm, cropped_boxes,
                                orig_img, os.path.basename(im_name))

            # Drain the writer queue before shutting everything down.
            while (writer.running()):
                time.sleep(1)
                print('===========================> Rendering remaining ' +
                      str(writer.count()) + ' images in the queue...')
            writer.stop()
            det_loader.stop()

        except KeyboardInterrupt:
            # Ctrl+C: kill the loader but still let the writer finish its queue.
            det_loader.terminate()
            while (writer.running()):
                time.sleep(1)
                print('===========================> Rendering remaining ' +
                      str(writer.count()) + ' images in the queue...')
            writer.stop()

        self.all_results = writer.results()
        self._save()
# Example 3
    def predict(self, image, img_name):
        """Run pose estimation on a single in-memory image.

        Args:
            image: decoded image passed straight to the DetectionLoader.
            img_name: name used to label this image in the results.

        Returns:
            The value of ``write_json`` applied to the writer's results
            (written to ``args.outputpath`` in ``args.format``).
        """
        args = self.args
        # Load detection loader
        det_loader = DetectionLoader(self.input_source, [img_name], [image],
                                     get_detector(args),
                                     self.cfg,
                                     args,
                                     batchSize=args.detbatch,
                                     mode=self.mode).start()

        # Init data writer
        queueSize = args.qsize
        self.writer = DataWriter(self.cfg,
                                 args,
                                 save_video=False,
                                 queueSize=queueSize).start()

        # dt = detection, pt = pose forward, pn = post-processing times.
        runtime_profile = {'dt': [], 'pt': [], 'pn': []}

        data_len = det_loader.length
        im_names_desc = tqdm(range(data_len), dynamic_ncols=True)

        batchSize = args.posebatch
        if args.flip:
            # Flip augmentation doubles each mini-batch, so halve it here.
            # BUGFIX: clamp to at least 1 — int(batchSize / 2) gave 0 when
            # posebatch == 1, causing ZeroDivisionError in the modulo below.
            batchSize = max(1, batchSize // 2)

        try:
            for i in im_names_desc:
                start_time = getTime()
                with torch.no_grad():
                    (inps, orig_img, im_name, boxes, scores, ids,
                     cropped_boxes) = det_loader.read()
                    # A None image signals the end of the stream.
                    if orig_img is None:
                        break
                    # No detections: emit an empty record for this image.
                    if boxes is None or boxes.nelement() == 0:
                        self.writer.save(None, None, None, None, None,
                                         orig_img, os.path.basename(im_name))
                        continue
                    if args.profile:
                        ckpt_time, det_time = getTime(start_time)
                        runtime_profile['dt'].append(det_time)
                    # Pose Estimation in mini-batches to bound device memory.
                    inps = inps.to(args.device)
                    datalen = inps.size(0)
                    leftover = 0
                    if (datalen) % batchSize:
                        leftover = 1
                    num_batches = datalen // batchSize + leftover
                    hm = []
                    for j in range(num_batches):
                        inps_j = inps[j * batchSize:min((j + 1) *
                                                        batchSize, datalen)]
                        if args.flip:
                            # Append the mirrored batch; the model runs both.
                            inps_j = torch.cat((inps_j, flip(inps_j)))
                        hm_j = self.pose_model(inps_j)
                        if args.flip:
                            # Un-flip the second half and average with the
                            # first half for a flip-ensembled heatmap.
                            hm_j_flip = flip_heatmap(hm_j[int(len(hm_j) / 2):],
                                                     det_loader.joint_pairs,
                                                     shift=True)
                            hm_j = (hm_j[0:int(len(hm_j) / 2)] + hm_j_flip) / 2
                        hm.append(hm_j)
                    hm = torch.cat(hm)
                    if args.profile:
                        ckpt_time, pose_time = getTime(ckpt_time)
                        runtime_profile['pt'].append(pose_time)
                    hm = hm.cpu()
                    self.writer.save(boxes, scores, ids, hm, cropped_boxes,
                                     orig_img, os.path.basename(im_name))

                    if args.profile:
                        ckpt_time, post_time = getTime(ckpt_time)
                        runtime_profile['pn'].append(post_time)

                if args.profile:
                    # TQDM
                    im_names_desc.set_description(
                        'det time: {dt:.4f} | pose time: {pt:.4f} | post processing: {pn:.4f}'
                        .format(dt=np.mean(runtime_profile['dt']),
                                pt=np.mean(runtime_profile['pt']),
                                pn=np.mean(runtime_profile['pn'])))
            # Drain the writer queue before shutting everything down.
            while (self.writer.running()):
                time.sleep(1)
                print('===========================> Rendering remaining ' +
                      str(self.writer.count()) + ' images in the queue...')
            self.writer.stop()
            det_loader.stop()
        except KeyboardInterrupt:
            self.print_finish_info(args)
            # Thread won't be killed when press Ctrl+C
            if args.sp:
                det_loader.terminate()
                while (self.writer.running()):
                    time.sleep(1)
                    print('===========================> Rendering remaining ' +
                          str(self.writer.count()) + ' images in the queue...')
                self.writer.stop()
            else:
                # subprocesses are killed, manually clear queues
                self.writer.commit()
                self.writer.clear_queues()
                # det_loader.clear_queues()
        final_result = self.writer.results()
        return write_json(final_result,
                          args.outputpath,
                          form=args.format,
                          for_eval=args.eval)