Example 1
def run_fastpose(args):
    """
        Run the AlphaPose model and measure its inference time.
    """
    cfg = update_config(args.fastpose_cfg)
    # Build the FastPose model
    pose_model = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)
    # Load the weights
    print('Loading pose model from %s...' % (args.checkpoint, ))
    pose_model.load_state_dict(
        torch.load(args.checkpoint, map_location=args.device))
    pose_model = pose_model.to('cuda:0')
    input_data = torch.randn(args.batch, 3, 256, 192,
                             dtype=torch.float32).to('cuda:0')
    # Convert the output to numpy for comparison with the accelerated result
    output_data_pytorch = pose_model(input_data).cpu().detach().numpy()
    # Run the model 100 times, then compute the average time
    nRound = 100
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(nRound):
        pose_model(input_data)
    torch.cuda.synchronize()
    time_pytorch = (time.time() - t0) / nRound
    # print('PyTorch time:', time_pytorch)
    return time_pytorch, output_data_pytorch
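A minimal driver sketch (not part of the snippet above): the Namespace fields mirror what run_fastpose reads, and the config/checkpoint paths are the stock FastPose files used in the later examples; `batch=8` is an arbitrary placeholder.

from argparse import Namespace

args = Namespace(
    fastpose_cfg='configs/coco/resnet/256x192_res50_lr1e-3_1x.yaml',
    checkpoint='pretrained_models/fast_res50_256x192.pth',
    device='cpu',  # map_location for torch.load; inference itself runs on cuda:0
    batch=8,
)
time_pytorch, output_pytorch = run_fastpose(args)
print('PyTorch avg latency: %.2f ms' % (time_pytorch * 1000))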
Example 2
def preset_model(cfg):
    model = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)

    if cfg.MODEL.TRY_LOAD:
        logger.info(f'Loading model from {cfg.MODEL.TRY_LOAD}...')

        map_location = None if len(opt.gpus) > 0 else torch.device('cpu')
        pretrained_state = torch.load(cfg.MODEL.TRY_LOAD, map_location=map_location)

        model_state = model.state_dict()
        pretrained_state = {k: v for k, v in pretrained_state.items()
                            if k in model_state and v.size() == model_state[k].size()}

        model_state.update(pretrained_state)
        model.load_state_dict(model_state)
    elif cfg.MODEL.PRETRAINED:
        logger.info(f'Loading model from {cfg.MODEL.PRETRAINED}...')

        map_location = None if len(opt.gpus) > 0 else torch.device('cpu')

        model.load_state_dict(torch.load(cfg.MODEL.PRETRAINED, map_location=map_location))
    else:
        logger.info('Create new model')
        logger.info('=> init weights')
        model._initialize()

    return model
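The TRY_LOAD branch above is the standard partial-loading pattern: keep only checkpoint tensors whose key and shape match the freshly built model. A self-contained sketch of the same idea, with a toy torch.nn.Linear standing in for the pose model:

import torch

model = torch.nn.Linear(4, 2)
# Toy checkpoint: 'weight' matches the model, 'bias' deliberately does not.
checkpoint = {'weight': torch.zeros(2, 4), 'bias': torch.zeros(3)}
model_state = model.state_dict()
compatible = {k: v for k, v in checkpoint.items()
              if k in model_state and v.size() == model_state[k].size()}
model_state.update(compatible)
model.load_state_dict(model_state)  # 'weight' is loaded; the mismatched 'bias' keeps its init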
Example 3
    def __init__(self, args):
        self.cfg = update_config(args.cfg)

        args.gpus = [int(i) for i in args.gpus.split(',')
                     ] if torch.cuda.device_count() >= 1 else [-1]
        args.device = torch.device(
            "cuda:" + str(args.gpus[0]) if args.gpus[0] >= 0 else "cpu")
        args.detbatch = args.detbatch * len(args.gpus)
        args.posebatch = args.posebatch * len(args.gpus)
        args.tracking = (args.detector == 'tracker')

        self.mode, self.input_source = self.check_input(args)

        # Load pose model
        self.pose_model = builder.build_sppe(self.cfg.MODEL,
                                             preset_cfg=self.cfg.DATA_PRESET)

        print(f'Loading pose model from {args.checkpoint}...')
        self.pose_model.load_state_dict(
            torch.load(args.checkpoint, map_location=args.device))

        if len(args.gpus) > 1:
            self.pose_model = torch.nn.DataParallel(self.pose_model,
                                                    device_ids=args.gpus).to(
                                                        args.device)
        else:
            self.pose_model.to(args.device)
        self.pose_model.eval()

        self.args = args
Example 4
    def __init__(self):
        self.device = try_gpu()
        self.cfg = update_config(
            'configs/coco/resnet/256x192_res50_lr1e-3_1x.yaml')

        self.detector = get_detector({'detector': 'yolo'})
        self.detector.load_model()

        self.pose_net = builder.build_sppe(self.cfg.MODEL,
                                           preset_cfg=self.cfg.DATA_PRESET)
        self.pose_net.load_state_dict(
            torch.load('pretrained_models/fast_res50_256x192.pth',
                       map_location=self.device))

        self.pose_net.to(self.device)
Example 5
	def __init__(self, args, cfg):
		self.args = args
		self.cfg = cfg

		# Load pose model
		self.pose_model = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)

		print(f'Loading pose model from {args.checkpoint}...')
		self.pose_model.load_state_dict(torch.load(args.checkpoint, map_location=args.device))
		self.pose_dataset = builder.retrieve_dataset(cfg.DATASET.TRAIN)

		self.pose_model.to(args.device)
		self.pose_model.eval()
		
		self.det_loader = DetectionLoader(get_detector(self.args), self.cfg, self.args)
Example 6
def build_poser(pose_opt, gpus_id='0'):
    # Load pose model
    pose_model = builder.build_sppe(pose_opt.MODEL, preset_cfg=pose_opt.DATA_PRESET)
    pose_model.cuda()
    # gpus_id = [int(i) for i in gpus_id.split(',')] if torch.cuda.device_count() >= 1 else [-1]
    # device = torch.device("cuda:" + str(gpus_id[0]) if gpus_id[0] >= 0 else "cpu")

    # print(f'Loading alphapose model from {pose_opt.MODEL.checkpoint}...')
    pose_model.load_state_dict(torch.load(pose_opt.MODEL.checkpoint))

    # if len(gpus_id) > 1:
    #     pose_model = torch.nn.DataParallel(pose_model, device_ids=gpus_id).to(device)
    # else:
    #     pose_model.to(device)
    pose_model.eval()

    return pose_model
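A hypothetical call, assuming pose_opt is a parsed AlphaPose config whose MODEL node has been given a checkpoint entry (the custom field this build_poser reads); the paths reuse the config and weights from the other examples on this page.

pose_opt = update_config('configs/coco/resnet/256x192_res50_lr1e-3_1x.yaml')
pose_opt.MODEL.checkpoint = 'pretrained_models/fast_res50_256x192.pth'
pose_model = build_poser(pose_opt)
with torch.no_grad():
    heatmaps = pose_model(torch.randn(1, 3, 256, 192).cuda())  # one heatmap channel per keypoint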
Example 7
    def assemble_AlphaPose(self):
        from alphapose.models import builder

        with open(self.cfg_file, 'r') as f:
            cfg = EasyDict(yaml.load(f, Loader=yaml.FullLoader))
        device = select_device(cfg['device'])
        model = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)
        print(f'Loading weights from {cfg["weights"]}')
        model.load_state_dict(torch.load(cfg['weights'], map_location=device))
        # print(cfg.device, device)
        if len(cfg['device'].split(',')) >= 2:
            model.to(device)
            device_ids = [int(d) for d in cfg['device'].split(',')]
            model = torch.nn.DataParallel(model, device_ids=device_ids).to(device)
        else:
            model.to(device)

        return model, cfg, device
Example 8
def preset_model(cfg):
    model = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)

    if cfg.MODEL.PRETRAINED:
        logger.info(f'Loading model from {cfg.MODEL.PRETRAINED}...')
        model.load_state_dict(torch.load(cfg.MODEL.PRETRAINED), strict=False)
    elif cfg.MODEL.TRY_LOAD:
        logger.info(f'Loading model from {cfg.MODEL.TRY_LOAD}...')
        pretrained_state = torch.load(cfg.MODEL.TRY_LOAD)
        model_state = model.state_dict()
        pretrained_state = {
            k: v
            for k, v in pretrained_state.items()
            if k in model_state and v.size() == model_state[k].size()
        }

        model_state.update(pretrained_state)
        model.load_state_dict(model_state)
    else:
        logger.info('Create new model')
        logger.info('=> init weights')
        model._initialize()

    return model
Example 9
def print_finish_info():
    print('===========================> Finish Model Running.')
    if (args.save_img or args.save_video) and not args.vis_fast:
        print(
            '===========================> Rendering remaining images in the queue...'
        )
        print(
            '===========================> If this step takes too long, you can enable the --vis_fast flag to use fast rendering (real-time).'
        )


if __name__ == "__main__":

    # Load pose model
    pose_model = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)
    print(f'Loading pose model from {args.checkpoint}...')
    pose_model.load_state_dict(
        torch.load(args.checkpoint, map_location=args.device))

    if len(args.gpus) > 1:
        pose_model = torch.nn.DataParallel(
            pose_model, device_ids=args.gpus).to(args.device)
    else:
        pose_model.to(args.device)
    pose_model.eval()

    for game in ['FM']:
        # Load detected imgs

        input_source = '/datanew/hwb/data/WG_Num/{}/JPEGImages'.format(game)
Example 10
    def run(self):

        if os.path.isfile(self.video):
            mode, input_source = 'video', self.video
        else:
            raise IOError(
                'Error: --video must refer to a video file, not directory.')

        if not os.path.exists(self.outputpath):
            os.makedirs(self.outputpath)

        det_loader = DetectionLoader(input_source,
                                     get_detector(self),
                                     self.cfg,
                                     self,
                                     batchSize=self.detbatch,
                                     mode=mode,
                                     queueSize=self.qsize)
        det_worker = det_loader.start()

        # Load pose model
        pose_model = builder.build_sppe(self.cfg.MODEL,
                                        preset_cfg=self.cfg.DATA_PRESET)

        print(f'Loading pose model from {self.checkpoint}...')
        pose_model.load_state_dict(
            torch.load(self.checkpoint, map_location=self.device))

        if self.pose_track:
            tracker = Tracker(tcfg, self)

        pose_model.to(self.device)
        pose_model.eval()

        if self.save_video:
            from alphapose.utils.writer import DEFAULT_VIDEO_SAVE_OPT as video_save_opt
            video_save_opt['savepath'] = os.path.join(
                self.outputpath, os.path.basename(self.video))
            video_save_opt.update(det_loader.videoinfo)
            writer = DataWriter(self.cfg,
                                self,
                                save_video=True,
                                video_save_opt=video_save_opt,
                                queueSize=self.qsize).start()
        else:
            writer = DataWriter(self.cfg,
                                self,
                                save_video=False,
                                queueSize=self.qsize).start()

        data_len = det_loader.length
        im_names_desc = tqdm(range(data_len), dynamic_ncols=True)

        batchSize = self.posebatch

        try:
            for i in im_names_desc:
                start_time = getTime()
                with torch.no_grad():
                    (inps, orig_img, im_name, boxes, scores, ids,
                     cropped_boxes) = det_loader.read()
                    if orig_img is None:
                        break
                    if boxes is None or boxes.nelement() == 0:
                        writer.save(None, None, None, None, None, orig_img,
                                    os.path.basename(im_name))
                        continue

                    # Pose Estimation
                    inps = inps.to(self.device)
                    datalen = inps.size(0)

                    leftover = 0
                    if (datalen) % batchSize:
                        leftover = 1
                    num_batches = datalen // batchSize + leftover

                    hm = []
                    for j in range(num_batches):
                        inps_j = inps[j * batchSize:min((j + 1) *
                                                        batchSize, datalen)]
                        hm_j = pose_model(inps_j)
                        hm.append(hm_j)

                    hm = torch.cat(hm)
                    #hm = hm.cpu()
                    if self.pose_track:
                        boxes, scores, ids, hm, cropped_boxes = track(
                            tracker, self, orig_img, inps, boxes, hm,
                            cropped_boxes, im_name, scores)
                    writer.save(boxes, scores, ids, hm, cropped_boxes,
                                orig_img, os.path.basename(im_name))

            while (writer.running()):
                time.sleep(1)
                print('===========================> Rendering remaining ' +
                      str(writer.count()) + ' images in the queue...')
            writer.stop()
            det_loader.stop()

        except KeyboardInterrupt:
            det_loader.terminate()
            while (writer.running()):
                time.sleep(1)
                print('===========================> Rendering remaining ' +
                      str(writer.count()) + ' images in the queue...')
            writer.stop()

        self.all_results = writer.results()
        self._save()
Example 11
    def __init__(self, args=None):

        if args is None:

            args = Namespace(
                # Pose config
                pose_cfg='configs/coco/resnet/256x192_res50_lr1e-3_1x.yaml',
                # Pose checkpoint
                pose_checkpoint='pretrained_models/fast_res50_256x192.pth',
                # GPUS
                gpus='0',
                # Detection thresh
                det_thresh=0.5,
                # Detection config
                det_cfg='mmDetection/gfl_x101_611.py',
                # Detection checkpoint
                det_checkpoint='mmDetection/weights.pth',
                # Show clothing color
                clothe_color=True,
                # show bboxes
                showbox=True

            )
    
        
        self.pose_cfg = update_config(args.pose_cfg)
        

        # Device configuration
        args.gpus = [int(i) for i in args.gpus.split(',')] if torch.cuda.device_count() >= 1 else [-1]
        args.device = torch.device("cuda:" + str(args.gpus[0]) if args.gpus[0] >= 0 else "cpu")
        args.tracking = False
        args.pose_track = False

        # Copy args
        self.args = args

        # Preprocess transformation
        pose_dataset = builder.retrieve_dataset(self.pose_cfg.DATASET.TRAIN)
        self.transformation = SimpleTransform(
            pose_dataset, scale_factor=0,
            input_size=self.pose_cfg.DATA_PRESET.IMAGE_SIZE,
            output_size=self.pose_cfg.DATA_PRESET.HEATMAP_SIZE,
            rot=0, sigma=self.pose_cfg.DATA_PRESET.SIGMA,
            train=False, add_dpg=False, gpu_device=args.device)

        self.norm_type = self.pose_cfg.LOSS.get('NORM_TYPE', None)

        # Post process        
        self.heatmap_to_coord = get_func_heatmap_to_coord(self.pose_cfg)


        # Load Detector Model
        self.det_model = init_detector(args.det_cfg, checkpoint=args.det_checkpoint, device=args.device)

        # Load pose model
        self.pose_model = builder.build_sppe(self.pose_cfg.MODEL, preset_cfg=self.pose_cfg.DATA_PRESET)

        print(f'Loading pose model from {args.pose_checkpoint}...')
        self.pose_model.load_state_dict(torch.load(args.pose_checkpoint, map_location=args.device))

        self.pose_model.to(args.device)
        self.pose_model.eval()
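A hedged single-image inference sketch, written as a hypothetical extra method for the class above (assumes cv2 is imported); test_transform follows AlphaPose's SimpleTransform API, and the image path and whole-image box are placeholders.

    def estimate(self, img_path='demo.jpg'):
        # Read an image and treat the full frame as one person box (placeholder logic).
        img = cv2.imread(img_path)
        box = [0, 0, img.shape[1], img.shape[0]]
        inp, cropped_box = self.transformation.test_transform(img, box)
        with torch.no_grad():
            hm = self.pose_model(inp.unsqueeze(0).to(self.args.device))
        # Map the heatmap peaks back to image coordinates.
        coords, scores = self.heatmap_to_coord(hm[0].cpu(), cropped_box)
        return coords, scores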
Example 12
    def start(self):
        parser = argparse.ArgumentParser(description='AlphaPose Demo')
        parser.add_argument(
            '--cfg',
            type=str,
            required=False,
            help='experiment configure file name',
            default=
            "./AlphaPose/configs/coco/resnet/256x192_res50_lr1e-3_1x.yaml")
        parser.add_argument(
            '--checkpoint',
            type=str,
            required=False,
            help='checkpoint file name',
            default="./AlphaPose/pretrained_models/fast_res50_256x192.pth")
        parser.add_argument('--sp',
                            default=False,
                            action='store_true',
                            help='Use single process for pytorch')
        parser.add_argument('--detector',
                            dest='detector',
                            help='detector name',
                            default="yolo")
        parser.add_argument('--detfile',
                            dest='detfile',
                            help='detection result file',
                            default="")
        parser.add_argument('--indir',
                            dest='inputpath',
                            help='image-directory',
                            default="./media/img")
        parser.add_argument('--list',
                            dest='inputlist',
                            help='image-list',
                            default="")
        parser.add_argument('--image',
                            dest='inputimg',
                            help='image-name',
                            default="")
        parser.add_argument('--outdir',
                            dest='outputpath',
                            help='output-directory',
                            default="./output")
        parser.add_argument('--save_img',
                            default=True,
                            action='store_true',
                            help='save result as image')
        parser.add_argument('--vis',
                            default=False,
                            action='store_true',
                            help='visualize image')
        parser.add_argument('--showbox',
                            default=False,
                            action='store_true',
                            help='visualize human bbox')
        parser.add_argument('--profile',
                            default=False,
                            action='store_true',
                            help='add speed profiling at screen output')
        parser.add_argument(
            '--format',
            type=str,
            help=
            'save in the format of cmu or coco or openpose, option: coco/cmu/open',
            default="open")
        parser.add_argument('--min_box_area',
                            type=int,
                            default=0,
                            help='min box area to filter out')
        parser.add_argument('--detbatch',
                            type=int,
                            default=1,
                            help='detection batch size PER GPU')
        parser.add_argument('--posebatch',
                            type=int,
                            default=30,
                            help='pose estimation maximum batch size PER GPU')
        parser.add_argument(
            '--eval',
            dest='eval',
            default=False,
            action='store_true',
            help=
            'save the result json as coco format, using image index(int) instead of image name(str)'
        )
        parser.add_argument(
            '--gpus',
            type=str,
            dest='gpus',
            default="0",
            help=
            'choose which cuda device to use by index and input comma to use multi gpus, e.g. 0,1,2,3. (input -1 for cpu only)'
        )
        parser.add_argument(
            '--qsize',
            type=int,
            dest='qsize',
            default=1024,
            help=
            'the length of result buffer, where reducing it will lower requirement of cpu memory'
        )
        parser.add_argument('--flip',
                            default=False,
                            action='store_true',
                            help='enable flip testing')
        parser.add_argument('--debug',
                            default=False,
                            action='store_true',
                            help='print detail information')
        """----------------------------- Video options -----------------------------"""
        parser.add_argument('--video',
                            dest='video',
                            help='video-name',
                            default="")
        parser.add_argument('--webcam',
                            dest='webcam',
                            type=int,
                            help='webcam number',
                            default=-1)
        parser.add_argument('--save_video',
                            dest='save_video',
                            help='whether to save rendered video',
                            default=False,
                            action='store_true')
        parser.add_argument('--vis_fast',
                            dest='vis_fast',
                            help='use fast rendering',
                            action='store_true',
                            default=False)
        """----------------------------- Tracking options -----------------------------"""
        parser.add_argument('--pose_flow',
                            dest='pose_flow',
                            help='track humans in video with PoseFlow',
                            action='store_true',
                            default=False)
        parser.add_argument('--pose_track',
                            dest='pose_track',
                            help='track humans in video with reid',
                            action='store_true',
                            default=True)

        args = parser.parse_args()
        cfg = update_config(args.cfg)

        if platform.system() == 'Windows':
            args.sp = True

        args.gpus = [int(i) for i in args.gpus.split(',')
                     ] if torch.cuda.device_count() >= 1 else [-1]
        args.device = torch.device(
            "cuda:" + str(args.gpus[0]) if args.gpus[0] >= 0 else "cpu")
        args.detbatch = args.detbatch * len(args.gpus)
        args.posebatch = args.posebatch * len(args.gpus)
        args.tracking = args.pose_track or args.pose_flow or args.detector == 'tracker'

        if not args.sp:
            torch.multiprocessing.set_start_method('forkserver', force=True)
            torch.multiprocessing.set_sharing_strategy('file_system')

        def check_input():
            # for webcam
            if args.webcam != -1:
                args.detbatch = 1
                return 'webcam', int(args.webcam)

            # for video
            if len(args.video):
                if os.path.isfile(args.video):
                    videofile = args.video
                    return 'video', videofile
                else:
                    raise IOError(
                        'Error: --video must refer to a video file, not directory.'
                    )

            # for detection results
            if len(args.detfile):
                if os.path.isfile(args.detfile):
                    detfile = args.detfile
                    return 'detfile', detfile
                else:
                    raise IOError(
                        'Error: --detfile must refer to a detection json file, not directory.'
                    )

            # for images
            if len(args.inputpath) or len(args.inputlist) or len(
                    args.inputimg):
                inputpath = args.inputpath
                inputlist = args.inputlist
                inputimg = args.inputimg

                if len(inputlist):
                    im_names = open(inputlist, 'r').readlines()
                elif len(inputpath) and inputpath != '/':
                    for root, dirs, files in os.walk(inputpath):
                        im_names = files
                    im_names = natsort.natsorted(im_names)
                elif len(inputimg):
                    args.inputpath = os.path.split(inputimg)[0]
                    im_names = [os.path.split(inputimg)[1]]

                return 'image', im_names

            else:
                raise NotImplementedError

        def print_finish_info():
            print('===========================> Finish Model Running.')
            if (args.save_img or args.save_video) and not args.vis_fast:
                print(
                    '===========================> Rendering remaining images in the queue...'
                )
                print(
                    '===========================> If this step takes too long, you can enable the --vis_fast flag to use fast rendering (real-time).'
                )

        def loop():
            n = 0
            while True:
                yield n
                n += 1

        # dirList = os.listdir(args.inputpath)
        # inDir = args.inputpath
        # outDir = args.outputpath
        # for i in dirList :
        mode, input_source = check_input()
        if not os.path.exists(args.outputpath):
            os.makedirs(args.outputpath)

        # Load detection loader
        if mode == 'webcam':
            det_loader = WebCamDetectionLoader(input_source,
                                               get_detector(args), cfg, args)
            det_worker = det_loader.start()
        elif mode == 'detfile':
            det_loader = FileDetectionLoader(input_source, cfg, args)
            det_worker = det_loader.start()
        else:
            det_loader = DetectionLoader(input_source,
                                         get_detector(args),
                                         cfg,
                                         args,
                                         batchSize=args.detbatch,
                                         mode=mode,
                                         queueSize=args.qsize)
            det_worker = det_loader.start()

        # Load pose model
        pose_model = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)

        print(f'Loading pose model from {args.checkpoint}...')
        pose_model.load_state_dict(
            torch.load(args.checkpoint, map_location=args.device))
        pose_dataset = builder.retrieve_dataset(cfg.DATASET.TRAIN)
        if args.pose_track:
            tracker = Tracker(tcfg, args)
        if len(args.gpus) > 1:
            pose_model = torch.nn.DataParallel(pose_model,
                                               device_ids=args.gpus).to(
                                                   args.device)
        else:
            pose_model.to(args.device)
        pose_model.eval()

        runtime_profile = {'dt': [], 'pt': [], 'pn': []}

        # Init data writer
        queueSize = 2 if mode == 'webcam' else args.qsize
        if args.save_video and mode != 'image':
            from alphapose.utils.writer import DEFAULT_VIDEO_SAVE_OPT as video_save_opt
            if mode == 'video':
                video_save_opt['savepath'] = os.path.join(
                    args.outputpath,
                    'AlphaPose_' + os.path.basename(input_source))
            else:
                video_save_opt['savepath'] = os.path.join(
                    args.outputpath,
                    'AlphaPose_webcam' + str(input_source) + '.mp4')
            video_save_opt.update(det_loader.videoinfo)
            writer = DataWriter(cfg,
                                args,
                                save_video=True,
                                video_save_opt=video_save_opt,
                                queueSize=queueSize).start()
        else:
            writer = DataWriter(cfg,
                                args,
                                save_video=False,
                                queueSize=queueSize).start()

        if mode == 'webcam':
            print('Starting webcam demo, press Ctrl + C to terminate...')
            sys.stdout.flush()
            im_names_desc = tqdm(loop())
        else:
            data_len = det_loader.length
            im_names_desc = tqdm(range(data_len), dynamic_ncols=True)

        batchSize = args.posebatch
        if args.flip:
            batchSize = int(batchSize / 2)
        try:
            self.percentage[2] = 'Analyzing joint information'
            for i in im_names_desc:
                start_time = getTime()
                # print(start_time)
                self.percentage[0] += 1
                #
                with torch.no_grad():
                    (inps, orig_img, im_name, boxes, scores, ids,
                     cropped_boxes) = det_loader.read()
                    if orig_img is None:
                        break
                    if boxes is None or boxes.nelement() == 0:
                        writer.save(None, None, None, None, None, orig_img,
                                    im_name)
                        continue
                    if args.profile:
                        ckpt_time, det_time = getTime(start_time)
                        runtime_profile['dt'].append(det_time)
                    # Pose Estimation
                    inps = inps.to(args.device)
                    datalen = inps.size(0)
                    leftover = 0
                    if (datalen) % batchSize:
                        leftover = 1
                    num_batches = datalen // batchSize + leftover
                    hm = []
                    for j in range(num_batches):
                        inps_j = inps[j * batchSize:min((j + 1) *
                                                        batchSize, datalen)]
                        if args.flip:
                            inps_j = torch.cat((inps_j, flip(inps_j)))
                        hm_j = pose_model(inps_j)
                        if args.flip:
                            hm_j_flip = flip_heatmap(hm_j[int(len(hm_j) / 2):],
                                                     pose_dataset.joint_pairs,
                                                     shift=True)
                            hm_j = (hm_j[0:int(len(hm_j) / 2)] + hm_j_flip) / 2
                        hm.append(hm_j)
                    hm = torch.cat(hm)
                    if args.profile:
                        ckpt_time, pose_time = getTime(ckpt_time)
                        runtime_profile['pt'].append(pose_time)
                    if args.pose_track:
                        boxes, scores, ids, hm, cropped_boxes = track(
                            tracker, args, orig_img, inps, boxes, hm,
                            cropped_boxes, im_name, scores)
                    hm = hm.cpu()
                    writer.save(boxes, scores, ids, hm, cropped_boxes,
                                orig_img, im_name)
                    if args.profile:
                        ckpt_time, post_time = getTime(ckpt_time)
                        runtime_profile['pn'].append(post_time)

                if args.profile:
                    # TQDM
                    im_names_desc.set_description(
                        'det time: {dt:.4f}  | pose time: {pt:.4f} | post processing: {pn:.4f}'
                        .format(dt=np.mean(runtime_profile['dt']),
                                pt=np.mean(runtime_profile['pt']),
                                pn=np.mean(runtime_profile['pn'])))
            print_finish_info()
            print("마무리 작업중...")
            while (writer.running()):
                time.sleep(1)
                print('===========================> Rendering remaining ' +
                      str(writer.count()) + ' images in the queue...')
            writer.stop()
            det_loader.stop()
            print("작업종료")
        except Exception as e:
            print(repr(e))
            print(
                'An error as above occurs when processing the images, please check it'
            )
            pass
        except KeyboardInterrupt:
            print_finish_info()
            # Thread won't be killed when press Ctrl+C
            if args.sp:
                det_loader.terminate()
                while (writer.running()):
                    time.sleep(1)
                    print('===========================> Rendering remaining ' +
                          str(writer.count()) + ' images in the queue...')
                writer.stop()
            else:
                # subprocesses are killed, manually clear queues

                det_loader.terminate()
                writer.terminate()
                writer.clear_queues()
                det_loader.clear_queues()
Example 13
def transform2onnx(args):
    cfg = update_config(args.cfg)
    # Build the model
    pose_model = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)
    # Load the weights
    print('Loading pose model from %s...' % (args.checkpoint, ))
    pose_model.load_state_dict(
        torch.load(args.checkpoint, map_location=args.device))
    pose_model = pose_model.to('cuda:0')
    input_names = ['input']
    output_names = ['output']
    # Decide whether the batch size should be dynamic
    dynamic = False
    if args.batch_size <= 0:
        dynamic = True
    # The batch size varies dynamically
    if dynamic:
        # Create a dummy input tensor
        dummy_input = torch.randn(1,
                                  3,
                                  args.height,
                                  args.width,
                                  dtype=torch.float32).to('cuda:0')
        onnx_file_name = "alphaPose_-1_3_{}_{}_dynamic.onnx".format(
            args.height, args.width)
        # dynamic_axes = {"input": [0, 2, 3]}
        dynamic_axes = {
            "input": {
                0: "batch_size",
                2: "height",
                3: "width"
            },
            "output": {
                0: "batch_size",
                2: "height",
                3: "width"
            }
        }
        # Export the model
        print('Export the onnx model ...')
        torch.onnx.export(pose_model,
                          dummy_input,
                          onnx_file_name,
                          export_params=True,
                          opset_version=11,
                          do_constant_folding=True,
                          input_names=input_names,
                          output_names=output_names,
                          dynamic_axes=dynamic_axes)

        print('Onnx model exporting done')
        return onnx_file_name
    else:
        # Create a dummy input tensor
        dummy_input = torch.randn(args.batch_size,
                                  3,
                                  args.height,
                                  args.width,
                                  dtype=torch.float32).to('cuda:0')
        onnx_file_name = "alphaPose_{}_3_{}_{}_dynamic.onnx".format(
            args.batch_size, args.height, args.width)
        print('Export the onnx model ...')
        # Convert the PyTorch model to ONNX
        torch.onnx.export(pose_model,
                          dummy_input,
                          onnx_file_name,
                          input_names=input_names,
                          output_names=output_names,
                          verbose=True,
                          opset_version=11)

        print('Onnx model exporting done')
        return onnx_file_name
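A hedged follow-up sketch, not part of the original function: sanity-check the file whose name transform2onnx returns, using onnxruntime (assumes the onnxruntime package and a 256x192 export, matching the config used elsewhere on this page).

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(onnx_file_name,
                            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
dummy = np.random.randn(1, 3, 256, 192).astype(np.float32)
onnx_out = sess.run(['output'], {'input': dummy})[0]  # names set in torch.onnx.export above
print('ONNX output shape:', onnx_out.shape)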