Example #1
    def init_fn(self):
        self.train_ds = MixedDataset(self.options, ignore_3d=self.options.ignore_3d, is_train=True)
        self.model = hmr(config.SMPL_MEAN_PARAMS, pretrained=True).to(self.device)  # HMR regression model
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=self.options.lr,
                                          weight_decay=0)
        self.smpl = SMPL(config.SMPL_MODEL_DIR,
                         batch_size=16,
                         create_transl=False).to(self.device)
        # Per-vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # Keypoint (2D and 3D) loss; no reduction so per-joint confidence weighting can be applied
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # SMPL parameter loss, used when ground-truth SMPL parameters are available
        self.criterion_regr = nn.MSELoss().to(self.device)

        self.models_dict = {'model': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
        self.focal_length = constants.FOCAL_LENGTH
        # Initialize MVSMPLify
        self.mvsmplify = MVSMPLify(step_size=1e-2, batch_size=16, num_iters=100, focal_length=self.focal_length)
        print(self.options.pretrained_checkpoint)
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(checkpoint_file=self.options.pretrained_checkpoint)
        # Load dictionary of fits
        self.fits_dict = FitsDict(self.options, self.train_ds)
        # Create renderer
        self.renderer = Renderer(focal_length=self.focal_length, img_res=224, faces=self.smpl.faces)
Example #2
    def init_fn(self):
        self.train_ds = MixedDataset(self.options, ignore_3d=self.options.ignore_3d, is_train=True)

        self.model = hmr(config.SMPL_MEAN_PARAMS, pretrained=True).to(self.device)
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=self.options.lr,
                                          weight_decay=0)
        self.smpl = SMPL(config.SMPL_MODEL_DIR,
                         batch_size=self.options.batch_size,
                         create_transl=False).to(self.device)
        # Per-vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # Keypoint (2D and 3D) loss
        # No reduction because confidence weighting needs to be applied
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss().to(self.device)
        self.models_dict = {'model': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
        self.focal_length = constants.FOCAL_LENGTH
        self.conf_thresh = self.options.conf_thresh

        # Initialize SMPLify fitting module
        self.smplify = SMPLify(step_size=1e-2,
                               batch_size=self.options.batch_size,
                               num_iters=self.options.num_smplify_iters,
                               focal_length=self.focal_length,
                               prior_mul=0.1,
                               conf_thresh=self.conf_thresh)
        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(checkpoint_file=self.options.pretrained_checkpoint)

        # Load dictionary of fits
        self.fits_dict = FitsDict(self.options, self.train_ds)

        # Create renderer
        self.renderer = Renderer(focal_length=self.focal_length, img_res=self.options.img_res, faces=self.smpl.faces)
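
The reduction='none' keypoint criterion above is typically combined with per-joint confidence weighting: the elementwise squared errors are multiplied by the detection confidences stored in the last channel of the ground-truth keypoints and then averaged, so missing or low-confidence joints contribute little to the loss. A minimal sketch of that weighting, with illustrative names that are not taken from the snippet:

import torch
import torch.nn as nn

criterion_keypoints = nn.MSELoss(reduction='none')

def keypoint_loss(pred_keypoints, gt_keypoints_with_conf):
    # gt_keypoints_with_conf: (B, K, 3), last channel is the per-joint confidence
    conf = gt_keypoints_with_conf[:, :, -1].unsqueeze(-1)  # (B, K, 1)
    per_joint = criterion_keypoints(pred_keypoints, gt_keypoints_with_conf[:, :, :-1])
    return (conf * per_joint).mean()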
Example #3
def bbox_from_json(bbox_file):
    """Get center and scale of bounding box from bounding box annotations.
    The expected format is [top_left(x), top_left(y), width, height].
    """
    with open(bbox_file, 'r') as f:
        bbox = np.array(json.load(f)['bbox']).astype(np.float32)
    ul_corner = bbox[:2]
    center = ul_corner + 0.5 * bbox[2:]
    # make sure the bounding box is rectangular
    width = max(bbox[2], bbox[3])
    scale = width / 200.0
    return center, scale


# Load pretrained model
model = hmr(config.SMPL_MEAN_PARAMS).to(device)
checkpoint = torch.load(args.checkpoint)
model.load_state_dict(checkpoint['model'], strict=False)

# Load SMPL model
smpl = SMPL(config.SMPL_MODEL_DIR, batch_size=1,
            create_transl=False).to(device)
model.eval()

# Setup renderer for visualization
renderer = Renderer(focal_length=constants.FOCAL_LENGTH,
                    img_res=constants.IMG_RES,
                    faces=smpl.faces)

# Preprocess input image and generate predictions
img, norm_img = process_image(args.img,
                              args.bbox,
                              args.openpose,
                              input_res=constants.IMG_RES)
with torch.no_grad():
    pred_rotmat, pred_betas, pred_camera = model(norm_img.to(device))
    pred_output = smpl(betas=pred_betas,
                       body_pose=pred_rotmat[:, 1:],
                       global_orient=pred_rotmat[:, 0].unsqueeze(1),
                       pose2rot=False)
    pred_vertices = pred_output.vertices

# Calculate camera parameters for rendering
camera_translation = torch.stack([
    pred_camera[:, 1], pred_camera[:, 2],
    2 * constants.FOCAL_LENGTH / (constants.IMG_RES * pred_camera[:, 0] + 1e-9)
], dim=-1)
camera_translation = camera_translation[0].cpu().numpy()
pred_vertices = pred_vertices[0].cpu().numpy()
img = img.permute(1, 2, 0).cpu().numpy()
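
For rendering, pred_camera above follows the common weak-perspective convention [s, tx, ty] predicted in the crop, and the torch.stack expression recovers the depth as 2·f / (res·s). A small NumPy sketch of the same conversion, using a hypothetical helper name:

import numpy as np

def weak_perspective_to_translation(pred_camera, focal_length, img_res):
    # pred_camera: [s, tx, ty]; returns the camera translation [tx, ty, tz]
    s, tx, ty = pred_camera
    tz = 2 * focal_length / (img_res * s + 1e-9)
    return np.array([tx, ty, tz], dtype=np.float32)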
Example #4
    def init_fn(self):
        self.train_ds = MixedDataset(self.options,
                                     ignore_3d=self.options.ignore_3d,
                                     is_train=True)

        self.model = hmr(config.SMPL_MEAN_PARAMS,
                         pretrained=True).to(self.device)

        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=self.options.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=0.95)

        self.smpl = SMPL(config.SMPL_MODEL_DIR,
                         batch_size=self.options.batch_size,
                         create_transl=False).to(self.device)

        # consistency loss
        self.criterion_consistency_contrastive = NTXent(
            tau=self.options.tau, kernel=self.options.kernel).to(self.device)
        self.criterion_consistency_mse = nn.MSELoss().to(self.device)
        # Per-vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # Keypoint (2D and 3D) loss
        # No reduction because confidence weighting needs to be applied
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss().to(self.device)
        self.models_dict = {'model': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}
        self.focal_length = constants.FOCAL_LENGTH

        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

        # Create renderer
        self.renderer = Renderer(focal_length=self.focal_length,
                                 img_res=self.options.img_res,
                                 faces=self.smpl.faces)

        # Create input image flag
        self.input_img = self.options.input_img

        # initialize queue
        self.feat_queue = FeatQueue(max_queue_size=self.options.max_queue_size)
Example #5
    else:
        center, scale = bbox_from_pkl(bbox_file)
    img = crop(img, center, scale, (input_res, input_res))
    img = img.astype(np.float32) / 255.
    img = torch.from_numpy(img).permute(2, 0, 1)
    norm_img = normalize_img(img.clone())[None]
    return img, norm_img


if __name__ == '__main__':
    args = parser.parse_args()
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load trained model
    model = hmr(config.SMPL_MEAN_PARAMS).to(device)
    checkpoint = torch.load(args.trained_model)
    model.load_state_dict(checkpoint['model'], strict=False)
    smpl = SMPL(config.SMPL_MODEL_DIR, batch_size=1,
                create_transl=False).to(device)
    model.eval()
    # Set up the renderer for visualization
    renderer = Renderer(focal_length=constants.FOCAL_LENGTH,
                        img_res=constants.IMG_RES,
                        faces=smpl.faces)
    # Process the image and predict the parameters
    img, norm_img = process_image(args.test_image,
                                  args.bbox,
                                  input_res=constants.IMG_RES)
    with torch.no_grad():
        pred_rotmat, pred_betas, pred_camera = model(norm_img.to(device))
Example #6
    def init_fn(self):
        if self.options.rank == 0:
            self.summary_writer.add_text('command_args', print_args())

        if self.options.regressor == 'hmr':
            # HMR/SPIN model
            self.model = hmr(path_config.SMPL_MEAN_PARAMS, pretrained=True)
            self.smpl = SMPL(path_config.SMPL_MODEL_DIR,
                             batch_size=cfg.TRAIN.BATCH_SIZE,
                             create_transl=False).to(self.device)
        elif self.options.regressor == 'pymaf_net':
            # PyMAF model
            self.model = pymaf_net(path_config.SMPL_MEAN_PARAMS,
                                   pretrained=True)
            self.smpl = self.model.regressor[0].smpl

        if self.options.distributed:
            # For multiprocessing distributed, DistributedDataParallel constructor
            # should always set the single device scope, otherwise,
            # DistributedDataParallel will use all available devices.
            if self.options.gpu is not None:
                torch.cuda.set_device(self.options.gpu)
                self.model.cuda(self.options.gpu)
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs we have
                self.options.batch_size = int(self.options.batch_size /
                                              self.options.ngpus_per_node)
                self.options.workers = int(
                    (self.options.workers + self.options.ngpus_per_node - 1) /
                    self.options.ngpus_per_node)
                self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.model)
                self.model = torch.nn.parallel.DistributedDataParallel(
                    self.model,
                    device_ids=[self.options.gpu],
                    output_device=self.options.gpu,
                    find_unused_parameters=True)
            else:
                self.model.cuda()
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                self.model = torch.nn.parallel.DistributedDataParallel(
                    self.model, find_unused_parameters=True)
            self.models_dict = {'model': self.model.module}
        else:
            self.model = self.model.to(self.device)
            self.models_dict = {'model': self.model}

        cudnn.benchmark = True

        # Per-vertex loss on the shape
        self.criterion_shape = nn.L1Loss().to(self.device)
        # Keypoint (2D and 3D) loss
        # No reduction because confidence weighting needs to be applied
        self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)
        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss().to(self.device)
        self.focal_length = constants.FOCAL_LENGTH

        if self.options.pretrained_checkpoint is not None:
            self.load_pretrained(
                checkpoint_file=self.options.pretrained_checkpoint)

        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=cfg.SOLVER.BASE_LR,
                                          weight_decay=0)

        self.optimizers_dict = {'optimizer': self.optimizer}

        if self.options.single_dataset:
            self.train_ds = BaseDataset(self.options,
                                        self.options.single_dataname,
                                        is_train=True)
        else:
            self.train_ds = MixedDataset(self.options, is_train=True)

        self.valid_ds = BaseDataset(self.options,
                                    self.options.eval_dataset,
                                    is_train=False)

        if self.options.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                self.train_ds)
            val_sampler = None
        else:
            train_sampler = None
            val_sampler = None

        self.train_data_loader = DataLoader(self.train_ds,
                                            batch_size=self.options.batch_size,
                                            num_workers=self.options.workers,
                                            pin_memory=cfg.TRAIN.PIN_MEMORY,
                                            shuffle=(train_sampler is None),
                                            sampler=train_sampler)

        self.valid_loader = DataLoader(dataset=self.valid_ds,
                                       batch_size=cfg.TEST.BATCH_SIZE,
                                       shuffle=False,
                                       num_workers=cfg.TRAIN.NUM_WORKERS,
                                       pin_memory=cfg.TRAIN.PIN_MEMORY,
                                       sampler=val_sampler)

        # Load dictionary of fits
        self.fits_dict = FitsDict(self.options, self.train_ds)
        self.evaluation_accumulators = dict.fromkeys([
            'pred_j3d', 'target_j3d', 'target_theta', 'pred_verts',
            'target_verts'
        ])

        # Create renderer
        try:
            self.renderer = OpenDRenderer()
        except Exception:
            print('No renderer for visualization.')
            self.renderer = None

        if cfg.MODEL.PyMAF.AUX_SUPV_ON:
            self.iuv_maker = IUV_Renderer(
                output_size=cfg.MODEL.PyMAF.DP_HEATMAP_SIZE)

        self.decay_steps_ind = 1
        self.decay_epochs_ind = 1
Example #7
        # Checkpoints tried during development; the last assignment takes effect.
        params = ['--checkpoint', '/home/hjoo/Dropbox (Facebook)/spinouput/10-31-50173-spin_all/checkpoints/2019_11_01-22_18_03.pt']      # wCOCO3D only early, first try
        params = ['--checkpoint', '/home/hjoo/Dropbox (Facebook)/spinouput/11-04-59961-filShp3_ours_coco3d_all-4030/checkpoints/2019_11_04-18_55_04-best-58.5394948720932.pt']
        params = ['--checkpoint', 'data/model_checkpoint.pt']

        params = ['--checkpoint', '/home/hjoo/Dropbox (Facebook)/spinouput/11-06-42861-upper0_2_ours_lc3d_all-8935/checkpoints/2019_11_06-13_05_50-best-55.38778007030487.pt']
        params = ['--checkpoint', '/home/hjoo/Dropbox (Facebook)/spinouput/11-07-46582-w_upper0_2_ours_lc3d_all-8143/checkpoints/2019_11_07-17_32_54-best-55.422715842723846.pt']

        # params += ['--num_workers', 0]

        args = parser.parse_args(params)
        args.batch_size = 128
        args.num_workers = 0
        

    model = hmr(config.SMPL_MEAN_PARAMS)
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['model'], strict=False)
    model.cuda()
    model.eval()

    # Setup evaluation dataset
    # dataset = BaseDataset(None, '3dpw', is_train=False, bMiniTest=False)
    dataset = BaseDataset(None, '3dpw', is_train=False, bMiniTest=False, bEnforceUpperOnly=False)
    
    # Run evaluation
    # result_file_name = '/run/media/hjoo/disk/data/cocoPose3D_amt/0_SPIN/result_3dpw_urs_11_04_59961_4030.pkl'
    result_file_name = '/run/media/hjoo/disk/data/cocoPose3D_amt/0_SPIN/spin_11-06-42861-upper0_2_ours_lc3d_all-8935.pkl'
    run_evaluation(model, '3dpw', dataset, result_file_name,
                   batch_size=args.batch_size,
                   shuffle=args.shuffle,
Example #8
        full_arch_name = full_arch_name[:8] + '...'
    print(
        '| ' + full_arch_name + ' ' +
        ' '.join(['| {:.3f}'.format(value) for value in values]) +
        ' |'
    )


if __name__ == '__main__':
    args = parser.parse_args()
    parse_args(args)

    if args.regressor == 'pymaf_net':
        model = pymaf_net(path_config.SMPL_MEAN_PARAMS, pretrained=True)
    elif args.regressor == 'hmr':
        model = hmr(path_config.SMPL_MEAN_PARAMS)

    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint)
        model.load_state_dict(checkpoint['model'], strict=True)

    model.eval()

    dataset = COCODataset(None, args.dataset, 'val2014', is_train=False)

    # Run evaluation
    args.result_file = None
    run_evaluation(model, args.dataset, dataset, args.result_file,
                   batch_size=args.batch_size,
                   shuffle=args.shuffle,
                   log_freq=args.log_freq, options=args)
Example #9
def main(args):
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    if args.image_folder is None:
        video_file = args.vid_file

        # ========= [Optional] download the youtube video ========= #
        if video_file.startswith('https://www.youtube.com'):
            print(f'Downloading YouTube video \"{video_file}\"')
            video_file = download_youtube_clip(video_file, '/tmp')

            if video_file is None:
                exit('YouTube URL is not valid!')

            print(f'YouTube Video has been downloaded to {video_file}...')

        if not os.path.isfile(video_file):
            exit(f'Input video \"{video_file}\" does not exist!')

        output_path = os.path.join(
            args.output_folder,
            os.path.basename(video_file).replace('.mp4', ''))

        image_folder, num_frames, img_shape = video_to_images(video_file,
                                                              return_info=True)
    else:
        image_folder = args.image_folder
        num_frames = len(os.listdir(image_folder))
        img_shape = cv2.imread(
            osp.join(image_folder,
                     os.listdir(image_folder)[0])).shape

        output_path = os.path.join(args.output_folder,
                                   osp.split(image_folder)[-1])

    os.makedirs(output_path, exist_ok=True)

    print(f'Number of input video frames: {num_frames}')
    if not args.image_based:
        orig_height, orig_width = img_shape[:2]

    total_time = time.time()

    # ========= Run tracking ========= #
    bbox_scale = 1.0
    if args.use_gt:
        with open(args.anno_file) as f:
            tracking_anno = json.load(f)
        tracking_results = {}
        for tracklet in tracking_anno:
            track_id = tracklet['idx']
            frames = tracklet['frames']
            f_id = []
            bbox = []
            for f in frames:
                f_id.append(f['frame_id'])
                x_tl, y_tl = f['rect']['tl']['x'] * orig_width, f['rect'][
                    'tl']['y'] * orig_height
                x_br, y_br = f['rect']['br']['x'] * orig_width, f['rect'][
                    'br']['y'] * orig_height

                x_c, y_c = (x_br + x_tl) / 2., (y_br + y_tl) / 2.
                w, h = x_br - x_tl, y_br - y_tl
                wh_max = max(w, h)
                x_tl, y_tl = x_c - wh_max / 2., y_c - wh_max / 2.

                bbox.append(np.array([x_c, y_c, wh_max, wh_max]))
            f_id = np.array(f_id)
            bbox = np.array(bbox)
            tracking_results[track_id] = {'frames': f_id, 'bbox': bbox}
    else:
        # run multi object tracker
        mot = MPT(
            device=device,
            batch_size=args.tracker_batch_size,
            display=args.display,
            detector_type=args.detector,
            output_format='dict',
            yolo_img_size=args.yolo_img_size,
        )
        tracking_results = mot(image_folder)

    # remove tracklets if num_frames is less than MIN_NUM_FRAMES
    for person_id in list(tracking_results.keys()):
        if tracking_results[person_id]['frames'].shape[0] < MIN_NUM_FRAMES:
            del tracking_results[person_id]

    # ========= Define model ========= #
    if args.regressor == 'pymaf_net':
        model = pymaf_net(path_config.SMPL_MEAN_PARAMS,
                          pretrained=True).to(device)
    elif args.regressor == 'hmr':
        model = hmr(path_config.SMPL_MEAN_PARAMS).to(device)

    # ========= Load pretrained weights ========= #
    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint)
        model.load_state_dict(checkpoint['model'], strict=True)

    model.eval()
    print(f'Loaded pretrained weights from \"{args.checkpoint}\"')

    # ========= Run pred on each person ========= #
    if args.recon_result_file:
        pred_results = joblib.load(args.recon_result_file)
        print('Loaded results from ' + args.recon_result_file)
    else:
        if args.pre_load_imgs:
            image_file_names = [
                osp.join(image_folder, x) for x in os.listdir(image_folder)
                if x.endswith('.png') or x.endswith('.jpg')
            ]
            image_file_names = sorted(image_file_names)
            image_file_names = np.array(image_file_names)
            pre_load_imgs = []
            for file_name in image_file_names:
                pre_load_imgs.append(
                    cv2.cvtColor(cv2.imread(file_name), cv2.COLOR_BGR2RGB))
            pre_load_imgs = np.array(pre_load_imgs)
            print('pre_load_imgs', pre_load_imgs.shape)
        else:
            image_file_names = None
        print('Running reconstruction on each tracklet...')
        pred_time = time.time()
        pred_results = {}
        for person_id in tqdm(list(tracking_results.keys())):
            bboxes = joints2d = None

            if args.tracking_method == 'bbox':
                bboxes = tracking_results[person_id]['bbox']
            elif args.tracking_method == 'pose':
                joints2d = tracking_results[person_id]['joints2d']

            frames = tracking_results[person_id]['frames']

            if args.pre_load_imgs:
                print('pre_load_imgs[frames]', pre_load_imgs[frames].shape)
                dataset = Inference(image_folder=image_folder,
                                    frames=frames,
                                    bboxes=bboxes,
                                    joints2d=joints2d,
                                    scale=bbox_scale,
                                    pre_load_imgs=pre_load_imgs[frames])
            else:
                dataset = Inference(
                    image_folder=image_folder,
                    frames=frames,
                    bboxes=bboxes,
                    joints2d=joints2d,
                    scale=bbox_scale,
                )

            if args.image_based:
                img_shape = cv2.imread(
                    osp.join(image_folder,
                             os.listdir(image_folder)[frames[0]])).shape
                orig_height, orig_width = img_shape[:2]

            bboxes = dataset.bboxes
            frames = dataset.frames
            has_keypoints = joints2d is not None

            dataloader = DataLoader(dataset,
                                    batch_size=args.model_batch_size,
                                    num_workers=16)

            with torch.no_grad():

                pred_cam, pred_verts, pred_pose, pred_betas, pred_joints3d, norm_joints2d = [], [], [], [], [], []

                for batch in dataloader:
                    if has_keypoints:
                        batch, nj2d = batch
                        norm_joints2d.append(nj2d.numpy().reshape(-1, 21, 3))

                    # batch = batch.unsqueeze(0)
                    batch = batch.to(device)

                    # batch_size, seqlen = batch.shape[:2]
                    batch_size = batch.shape[0]
                    seqlen = 1
                    preds_dict, _ = model(batch)

                    output = preds_dict['smpl_out'][-1]

                    pred_cam.append(output['theta'][:, :3].reshape(
                        batch_size * seqlen, -1))
                    pred_verts.append(output['verts'].reshape(
                        batch_size * seqlen, -1, 3))
                    pred_pose.append(output['theta'][:, 3:75].reshape(
                        batch_size * seqlen, -1))
                    pred_betas.append(output['theta'][:, 75:].reshape(
                        batch_size * seqlen, -1))
                    pred_joints3d.append(output['kp_3d'].reshape(
                        batch_size * seqlen, -1, 3))
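                    # theta is assumed to pack [cam (3) | pose (72) | betas], which is
                    # why the slices [:, :3], [:, 3:75], and [:, 75:] are used above.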

                pred_cam = torch.cat(pred_cam, dim=0)
                pred_verts = torch.cat(pred_verts, dim=0)
                pred_pose = torch.cat(pred_pose, dim=0)
                pred_betas = torch.cat(pred_betas, dim=0)
                pred_joints3d = torch.cat(pred_joints3d, dim=0)

                del batch

            # ========= Save results to a pickle file ========= #
            pred_cam = pred_cam.cpu().numpy()
            pred_verts = pred_verts.cpu().numpy()
            pred_pose = pred_pose.cpu().numpy()
            pred_betas = pred_betas.cpu().numpy()
            pred_joints3d = pred_joints3d.cpu().numpy()

            orig_cam = convert_crop_cam_to_orig_img(cam=pred_cam,
                                                    bbox=bboxes,
                                                    img_width=orig_width,
                                                    img_height=orig_height)

            output_dict = {
                'pred_cam': pred_cam,
                'orig_cam': orig_cam,
                'verts': pred_verts,
                'pose': pred_pose,
                'betas': pred_betas,
                'joints3d': pred_joints3d,
                'joints2d': joints2d,
                'bboxes': bboxes,
                'frame_ids': frames,
            }

            pred_results[person_id] = output_dict

        del model

        end = time.time()
        fps = num_frames / (end - pred_time)

        print(f'FPS: {fps:.2f}')
        total_time = time.time() - total_time
        print(
            f'Total time spent: {total_time:.2f} seconds (including model loading time).'
        )
        print(
            f'Total FPS (including model loading time): {num_frames / total_time:.2f}.'
        )

        print(
            f'Saving output results to \"{os.path.join(output_path, "_output.pkl")}\".'
        )

        joblib.dump(pred_results, os.path.join(output_path, "_output.pkl"))

    if not args.no_render:
        # ========= Render results as a single video ========= #
        if args.use_opendr:
            renderer = OpenDRenderer(resolution=(orig_height, orig_width))
        else:
            renderer = PyRenderer(resolution=(orig_width, orig_height))

        output_img_folder = os.path.join(
            output_path,
            osp.split(image_folder)[-1] + '_output')
        os.makedirs(output_img_folder, exist_ok=True)

        print(f'Rendering output video, writing frames to {output_img_folder}')

        # prepare results for rendering
        frame_results = prepare_rendering_results(pred_results, num_frames)

        image_file_names = sorted([
            os.path.join(image_folder, x) for x in os.listdir(image_folder)
            if x.endswith('.png') or x.endswith('.jpg')
        ])

        if args.regressor == 'hmr':
            color_type = 'pink'
        elif cfg.MODEL.PyMAF.N_ITER == 0 and not cfg.MODEL.PyMAF.AUX_SUPV_ON:
            color_type = 'neutral'
        else:
            color_type = 'purple'

        for frame_idx in tqdm(range(len(image_file_names))):
            img_fname = image_file_names[frame_idx]
            img = cv2.imread(img_fname)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            if args.render_ratio != 1:
                img = resize(img, (int(img.shape[0] * args.render_ratio),
                                   int(img.shape[1] * args.render_ratio)),
                             anti_aliasing=True)
                img = (img * 255).astype(np.uint8)

            raw_img = img.copy()

            # if args.sideview:
            #     side_img = np.zeros_like(img)

            if args.empty_bg:
                empty_img = np.zeros_like(img)

            for person_id, person_data in frame_results[frame_idx].items():
                frame_verts = person_data['verts']
                frame_cam = person_data['cam']

                mesh_filename = None

                if args.save_obj:
                    mesh_folder = os.path.join(output_path, 'meshes',
                                               f'{person_id:04d}')
                    os.makedirs(mesh_folder, exist_ok=True)
                    mesh_filename = os.path.join(mesh_folder,
                                                 f'{frame_idx:06d}.obj')

                if args.empty_bg:
                    img, empty_img = renderer(frame_verts[None, :, :] if
                                              args.use_opendr else frame_verts,
                                              img=[img, empty_img],
                                              cam=frame_cam,
                                              color_type=color_type,
                                              mesh_filename=mesh_filename)
                else:
                    img = renderer(frame_verts[None, :, :]
                                   if args.use_opendr else frame_verts,
                                   img=img,
                                   cam=frame_cam,
                                   color_type=color_type,
                                   mesh_filename=mesh_filename)

                # if args.sideview:
                #     side_img = renderer(
                #         frame_verts,
                #         img=side_img,
                #         cam=frame_cam,
                #         color_type=color_type,
                #         angle=270,
                #         axis=[0,1,0],
                #     )

            if args.with_raw:
                img = np.concatenate([raw_img, img], axis=1)

            if args.empty_bg:
                img = np.concatenate([img, empty_img], axis=1)

            # if args.sideview:
            #     img = np.concatenate([img, side_img], axis=1)

            # cv2.imwrite(os.path.join(output_img_folder, f'{frame_idx:06d}.png'), img)
            if args.image_based:
                imsave(
                    os.path.join(output_img_folder,
                                 osp.split(img_fname)[-1][:-4] + '.png'), img)
            else:
                imsave(os.path.join(output_img_folder, f'{frame_idx:06d}.png'),
                       img)

            if args.display:
                cv2.imshow('Video', img)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        if args.display:
            cv2.destroyAllWindows()

        # ========= Save rendered video ========= #
        vid_name = osp.split(
            image_folder
        )[-1] if args.image_folder is not None else os.path.basename(
            video_file)
        save_name = f'{vid_name.replace(".mp4", "")}_result.mp4'
        save_name = os.path.join(output_path, save_name)
        if not args.image_based:
            print(f'Saving result video to {save_name}')
            images_to_video(img_folder=output_img_folder,
                            output_vid_file=save_name)
        # shutil.rmtree(output_img_folder)

    # shutil.rmtree(image_folder)
    print('================= END =================')