def main():
    """Detect faces in every ``*.jpg`` under ``args.input_path`` with RetinaFace.

    All settings come from ``get_args()``. For every image with at least one
    detection (after confidence filtering and NMS):

    * if ``args.save_boxes``: writes ``<output_path>/labels/<file_id>.json``
      holding the file path, the file id and a list of
      ``{crop_id, bbox, score, landmarks}`` entries;
    * if ``args.save_crops``: writes every detected face to
      ``<output_path>/images/<file_id>/<file_id>_<crop_id>.jpg``.

    Raises:
        NotImplementedError: if ``args.network`` is neither ``"mobile0.25"``
            nor ``"resnet50"``.
    """
    args = get_args()
    torch.set_grad_enabled(False)

    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    else:
        raise NotImplementedError(
            f"Only mobile0.25 and resnet50 are supported, but we got {args.network}")

    # net and model
    net = RetinaFace(cfg=cfg, phase="test")
    net = load_model(net, args.trained_model, args.cpu)
    net.eval()
    if args.fp16:
        net = net.half()

    print("Finished loading model!")
    cudnn.benchmark = True
    device = torch.device("cpu" if args.cpu else "cuda")
    net = net.to(device)

    file_paths = sorted(args.input_path.rglob("*.jpg"))

    # Keep only this worker's share of the files when sharding over args.num_gpu.
    if args.num_gpu is not None:
        start, end = split_array(len(file_paths), args.num_gpu, args.gpu_id)
        file_paths = file_paths[start:end]

    output_path = args.output_path

    if args.save_boxes:
        output_label_path = output_path / "labels"
        output_label_path.mkdir(exist_ok=True, parents=True)

    if args.save_crops:
        output_image_path = output_path / "images"
        output_image_path.mkdir(exist_ok=True, parents=True)

    # Subtract per-channel means; std=1 and max_pixel_value=1 leave the pixel
    # scale itself unchanged.
    transform = albu.Compose(
        [albu.Normalize(p=1, mean=(104, 117, 123), std=(1.0, 1.0, 1.0), max_pixel_value=1)],
        p=1,
    )

    test_loader = DataLoader(
        InferenceDataset(file_paths, args.origin_size, transform=transform),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    with torch.no_grad():
        for raw_input in tqdm(test_loader):
            torched_images = raw_input["torched_image"]

            if args.fp16:
                torched_images = torched_images.half()

            resizes = raw_input["resize"]
            # The DataLoader collates the per-sample path strings into a list
            # (one entry per image); ``Path`` cannot be built from a list.
            image_paths = raw_input["image_path"]
            raw_images = raw_input["raw_image"]

            # Resume support: with one image per batch, skip images whose label
            # file already exists.
            if (args.batch_size == 1 and args.save_boxes
                    and (output_label_path / f"{Path(image_paths[0]).stem}.json").exists()):
                continue

            loc, conf, land = net(torched_images.to(device))  # forward pass

            batch_size = torched_images.shape[0]
            image_height, image_width = torched_images.shape[2:]

            # Multipliers turning decode()/decode_landm() outputs into pixel
            # coordinates: one (x_min, y_min, x_max, y_max) box and 5 (x, y)
            # landmark pairs.
            scale = torch.Tensor([image_width, image_height] * 2).to(device)
            scale1 = torch.Tensor([image_width, image_height] * 5).to(device)

            # Priors only depend on the (shared) input size of the batch.
            priors = PriorBox(cfg, image_size=(image_height, image_width)).forward().to(device)
            prior_data = priors.data

            for batch_id in range(batch_size):
                image_path = image_paths[batch_id]
                file_id = Path(image_path).stem
                raw_image = raw_images[batch_id]
                resize = resizes[batch_id].float()

                # Fresh list per image so each JSON holds only its own detections.
                labels = []

                boxes = decode(loc.data[batch_id], prior_data, cfg["variance"])
                boxes *= scale / resize
                scores = conf[batch_id][:, 1]

                landmarks = decode_landm(land.data[batch_id], prior_data, cfg["variance"])
                landmarks *= scale1 / resize

                # ignore low scores
                valid_index = torch.where(scores > args.confidence_threshold)[0]
                boxes = boxes[valid_index]
                landmarks = landmarks[valid_index]
                scores = scores[valid_index]

                order = scores.argsort(descending=True)
                boxes = boxes[order]
                landmarks = landmarks[order]
                scores = scores[order]

                # do NMS
                keep = nms(boxes, scores, args.nms_threshold)
                boxes = boxes[keep, :].int()
                landmarks = landmarks[keep].int()

                # Images without any detection get neither crops nor a label file.
                if boxes.shape[0] == 0:
                    continue

                scores = scores[keep].cpu().numpy().astype(np.float64)

                for crop_id, bbox in enumerate(boxes):
                    bbox = bbox.cpu().numpy()

                    labels.append({
                        "crop_id": crop_id,
                        "bbox": bbox.tolist(),
                        "score": scores[crop_id],
                        "landmarks": landmarks[crop_id].tolist(),
                    })

                    if args.save_crops:
                        x_min, y_min, x_max, y_max = bbox.tolist()
                        # Clamp at 0: a negative slice bound would wrap around
                        # to the far side of the image instead of clipping.
                        x_min, y_min = max(0, x_min), max(0, y_min)
                        x_max, y_max = max(0, x_max), max(0, y_max)

                        crop = raw_image[y_min:y_max, x_min:x_max].cpu().numpy()

                        # Empty after clamping (degenerate or off-image box):
                        # OpenCV cannot convert/encode an empty image.
                        if crop.size == 0:
                            continue

                        target_folder = output_image_path / f"{file_id}"
                        target_folder.mkdir(exist_ok=True, parents=True)

                        crop_file_path = target_folder / f"{file_id}_{crop_id}.jpg"

                        if crop_file_path.exists():
                            continue

                        cv2.imwrite(
                            str(crop_file_path),
                            cv2.cvtColor(crop, cv2.COLOR_BGR2RGB),
                            [int(cv2.IMWRITE_JPEG_QUALITY), 90],
                        )

                if args.save_boxes:
                    result = {
                        "file_path": image_path,
                        "file_id": file_id,
                        "bboxes": labels,
                    }

                    with open(output_label_path / f"{file_id}.json", "w") as f:
                        json.dump(result, f, indent=2)
def process_video_files(
    network: str,
    trained_model: str,
    decode_gpu: bool,
    is_fp16: bool,
    file_paths: list,
    num_gpu: Optional[int],
    gpu_id: int,
    output_path: Path,
    is_save_boxes: bool,
    is_save_crops: bool,
    num_frames: int,
    resize_coeff: Optional[Tuple],
    confidence_threshold: float,
    num_workers: int,
    nms_threshold: float,
    batch_size: int,
    resize_scale: float,
    min_size: int,
    keep_top_k: int,
) -> None:
    """Run RetinaFace on a list of videos, loading frames in a process pool.

    Frames are produced by ``get_frames`` in ``num_workers`` worker processes.
    Every non-empty result is passed to ``process_frames`` in this process,
    together with the run settings listed below. The network always runs on
    CUDA.

    Args:
        network: backbone name, ``"mobile0.25"`` or ``"resnet50"``.
        trained_model: checkpoint passed to ``load_model``.
        decode_gpu: forwarded to ``get_frames``.
        is_fp16: run the network in half precision.
        file_paths: videos to process.
        num_gpu: if not ``None``, number of shards ``file_paths`` is split into.
        gpu_id: index of the shard handled by this call.
        output_path, is_save_boxes, is_save_crops, confidence_threshold,
        nms_threshold, batch_size, resize_scale, min_size, keep_top_k:
            forwarded unchanged to ``process_frames``.
        num_frames, resize_coeff: forwarded to ``get_frames``.
        num_workers: number of frame-loading worker processes.

    Raises:
        NotImplementedError: if ``network`` is not a supported backbone.
        ValueError: if ``min_size`` is negative.
    """
    torch.set_grad_enabled(False)

    if network == "mobile0.25":
        cfg = cfg_mnet_test
    elif network == "resnet50":
        cfg = cfg_re50_test
    else:
        raise NotImplementedError(
            f"Only mobile0.25 and resnet50 are supported, but we got {network}")

    if min_size < 0:
        raise ValueError(
            f"Min size should be non-negative, but we got {min_size}.")

    # net and model
    net = RetinaFace(cfg=cfg, phase="test")
    net = load_model(net, trained_model, load_to_cpu=False)
    net.eval()

    if is_fp16:
        net = net.half()

    device = torch.device("cuda")
    net = net.to(device)

    print("Finished loading model!")
    cudnn.benchmark = True

    # Subtract per-channel means; std=1 and max_pixel_value=1 leave the pixel
    # scale itself unchanged.
    transform = albu.Compose(
        [albu.Normalize(p=1, mean=(104, 117, 123), std=(1.0, 1.0, 1.0), max_pixel_value=1)],
        p=1,
    )

    # Keep only this worker's shard of the videos.
    if num_gpu is not None:
        start, end = split_array(len(file_paths), num_gpu, gpu_id)
        file_paths = file_paths[start:end]

    func = partial(get_frames,
                   num_frames=num_frames,
                   resize_coeff=resize_coeff,
                   transform=transform,
                   decode_gpu=decode_gpu)

    # Settings identical for every video; built once instead of per result.
    shared_kwargs = {
        "is_fp16": is_fp16,
        "device": device,
        "batch_size": batch_size,
        "cfg": cfg,
        "nms_threshold": nms_threshold,
        "confidence_threshold": confidence_threshold,
        "is_save_crops": is_save_crops,
        "is_save_boxes": is_save_boxes,
        "output_path": output_path,
        "net": net,
        "min_size": min_size,
        "resize_scale": resize_scale,
        "keep_top_k": keep_top_k,
    }

    with torch.no_grad(), concurrent.futures.ProcessPoolExecutor(num_workers) as executor:
        for result in tqdm(executor.map(func, file_paths),
                           total=len(file_paths),
                           leave=False,
                           desc="Loading data files"):
            # get_frames signals "nothing to process" with an empty result.
            if len(result) == 0:
                continue
            # Shared settings take precedence over same-named keys in ``result``.
            process_frames(**{**result, **shared_kwargs})
def main():
    """Detect faces in frames sampled from every ``*.mp4`` under ``args.input_path``.

    All settings come from ``get_args()``. The first ``args.num_videos`` videos
    (in sorted order) are processed, optionally sharded over ``args.num_gpu``
    workers. For each video, frames are decoded on ``args.video_decoder``
    (``"cpu"`` or ``"gpu"``) and run through RetinaFace one at a time. Which
    frames are used depends on ``args.num_frames``:

    * ``None`` selects every frame;
    * ``1`` selects frame 0;
    * ``n > 1`` selects up to ``n`` evenly spaced frames.

    Outputs:

    * if ``args.save_boxes``: ``<output_path>/labels/<video_id>.json`` with the
      detections of all frames (only for videos with at least one detection);
    * if ``args.save_crops``: every face as
      ``<output_path>/images/<video_id>/<frame_id>_<crop_id>.jpg``.

    Raises:
        NotImplementedError: unknown ``args.network`` or ``args.video_decoder``.
        ValueError: ``args.num_frames`` is neither ``None`` nor a positive integer.
    """
    args = get_args()
    torch.set_grad_enabled(False)

    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    else:
        raise NotImplementedError(f"Only mobile0.25 and resnet50 are supported, but we got {args.network}")

    # net and model
    net = RetinaFace(cfg=cfg, phase="test")
    net = load_model(net, args.trained_model, args.cpu)
    net.eval()
    if args.fp16:
        net = net.half()

    print("Finished loading model!")
    cudnn.benchmark = True
    device = torch.device("cpu" if args.cpu else "cuda")
    net = net.to(device)

    file_paths = sorted(args.input_path.rglob("*.mp4"))[: args.num_videos]

    # Keep only this worker's share of the videos when sharding over args.num_gpu.
    if args.num_gpu is not None:
        start, end = split_array(len(file_paths), args.num_gpu, args.gpu_id)
        file_paths = file_paths[start:end]

    output_path = args.output_path

    if args.save_boxes:
        output_label_path = output_path / "labels"
        output_label_path.mkdir(exist_ok=True, parents=True)

    if args.save_crops:
        output_image_path = output_path / "images"
        output_image_path.mkdir(exist_ok=True, parents=True)

    if args.video_decoder == "cpu":
        decode_device = cpu(0)
    elif args.video_decoder == "gpu":
        decode_device = gpu(0)
    else:
        raise NotImplementedError(
            f"Only CPU and GPU devices are supported by decord, but got {args.video_decoder}"
        )

    # Subtract per-channel means; std=1 and max_pixel_value=1 leave the pixel
    # scale itself unchanged.
    transform = albu.Compose(
        [albu.Normalize(p=1, mean=(104, 117, 123), std=(1.0, 1.0, 1.0), max_pixel_value=1)],
        p=1,
    )

    with torch.no_grad():
        for video_path in tqdm(file_paths):
            labels = []
            video_id = video_path.stem

            with video_reader(str(video_path), ctx=decode_device) as video:
                len_video = len(video)

                if args.num_frames is None:
                    # No limit requested: use every frame of the video.
                    frame_ids = list(range(len_video))
                elif args.num_frames == 1:
                    frame_ids = [0]
                elif args.num_frames > 1:
                    # Evenly spaced frames; step 1 if the video is shorter than requested.
                    step = max(1, len_video // args.num_frames)
                    frame_ids = list(range(0, len_video, step))[: args.num_frames]
                else:
                    raise ValueError(
                        f"Expect None or a positive integer for args.num_frames, but got {args.num_frames}"
                    )

                frames = video.get_batch(frame_ids)

                if args.video_decoder == "cpu":
                    frames = frames.asnumpy()
                else:  # "gpu" (any other value was rejected above)
                    # Hand the decoded batch to torch via DLPack.
                    frames = dlpack.from_dlpack(frames.to_dlpack())
                    # Drop the reader and free cached GPU memory before inference.
                    del video
                    torch.cuda.empty_cache()
                    gc.collect()

            image_height, image_width = frames.shape[1], frames.shape[2]

            # Multipliers turning decode()/decode_landm() outputs into pixel
            # coordinates: one (x_min, y_min, x_max, y_max) box and 5 (x, y)
            # landmark pairs.
            scale = torch.Tensor([image_width, image_height] * 2).to(device)
            scale1 = torch.Tensor([image_width, image_height] * 5).to(device)

            # Priors only depend on the frame size, shared by the whole video.
            priors = PriorBox(cfg, image_size=(image_height, image_width)).forward().to(device)
            prior_data = priors.data

            # NOTE(review): ``resize`` divides the decoded coordinates below, but
            # the frame is never resized in this function (``prepare_image`` does
            # not receive ``resize`` and the priors use the original frame size).
            # Unless ``prepare_image`` resizes internally, a non-None
            # ``args.resize_coeff`` mis-scales boxes and landmarks — confirm.
            if args.resize_coeff is not None:
                target_size = min(args.resize_coeff)
                max_size = max(args.resize_coeff)

                image_size_min = min(image_width, image_height)
                image_size_max = max(image_width, image_height)

                resize = float(target_size) / float(image_size_min)
                if np.round(resize * image_size_max) > max_size:
                    resize = float(max_size) / float(image_size_max)
            else:
                resize = 1

            for pred_id, frame in enumerate(frames):
                frame_id = frame_ids[pred_id]

                torched_image = prepare_image(frame, transform, args.video_decoder).to(device)

                if args.fp16:
                    torched_image = torched_image.half()

                loc, conf, land = net(torched_image)  # forward pass

                boxes = decode(loc.data[0], prior_data, cfg["variance"])
                boxes *= scale / resize
                boxes = boxes.cpu().numpy()
                scores = conf[0].data.cpu().numpy()[:, 1]

                landmarks = decode_landm(land.data[0], prior_data, cfg["variance"])
                landmarks *= scale1 / resize
                landmarks = landmarks.cpu().numpy()

                # ignore low scores
                valid_index = np.where(scores > args.confidence_threshold)[0]
                boxes = boxes[valid_index]
                landmarks = landmarks[valid_index]
                scores = scores[valid_index]

                # Sort by descending score before NMS.
                order = scores.argsort()[::-1]
                boxes = boxes[order]
                landmarks = landmarks[order]
                scores = scores[order]

                # do NMS on rows of (x_min, y_min, x_max, y_max, score)
                detection = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
                keep = py_cpu_nms(detection, args.nms_threshold)
                detection = detection[keep, :]
                landmarks = landmarks[keep].astype(int)

                if detection.shape[0] == 0:
                    continue

                bboxes = detection[:, :4].astype(int)
                confidence = detection[:, 4].astype(np.float64)

                for crop_id, bbox in enumerate(bboxes):
                    labels.append(
                        {
                            "frame_id": int(frame_id),
                            "crop_id": crop_id,
                            "bbox": bbox.tolist(),
                            "score": confidence[crop_id],
                            "landmarks": landmarks[crop_id].tolist(),
                        }
                    )

                    if args.save_crops:
                        x_min, y_min, x_max, y_max = bbox.tolist()
                        # Clamp at 0: a negative slice bound would wrap around
                        # to the far side of the frame instead of clipping.
                        x_min, y_min = max(0, x_min), max(0, y_min)
                        x_max, y_max = max(0, x_max), max(0, y_max)

                        crop = frame[y_min:y_max, x_min:x_max]
                        if isinstance(crop, torch.Tensor):
                            # GPU decoding yields torch tensors; OpenCV needs a
                            # host numpy array.
                            crop = crop.cpu().numpy()

                        # Empty after clamping (degenerate or off-frame box):
                        # OpenCV cannot convert/encode an empty image.
                        if crop.size == 0:
                            continue

                        target_folder = output_image_path / f"{video_id}"
                        target_folder.mkdir(exist_ok=True, parents=True)

                        crop_file_path = target_folder / f"{frame_id}_{crop_id}.jpg"

                        if crop_file_path.exists():
                            continue

                        cv2.imwrite(
                            str(crop_file_path),
                            cv2.cvtColor(crop, cv2.COLOR_BGR2RGB),
                            [int(cv2.IMWRITE_JPEG_QUALITY), 90],
                        )

            # One label file per video, written after all frames. Videos without
            # any detection get no file.
            if args.save_boxes and labels:
                result = {
                    "file_path": str(video_path),
                    "file_id": video_id,
                    "bboxes": labels,
                }

                with open(output_label_path / f"{video_id}.json", "w") as f:
                    json.dump(result, f, indent=2)