Example #1
    def test_step(self, batch: Dict[str, torch.Tensor],
                  batch_idx: int) -> None:
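        """Post-process one test batch: decode boxes and landmarks, filter by
        confidence, run NMS, and write per-image JSON labels (plus optional
        visualizations)."""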
        torched_images = batch["torched_image"]
        resizes = batch["resize"]
        image_paths = batch["image_path"]
        raw_images = batch["raw_image"]

        labels: List[Dict[str, Any]] = []

        loc, conf, land = self.model(torched_images)
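        # softmax over the class dimension; column 1 is the face score used below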
        conf = F.softmax(conf, dim=-1)

        batch_size = torched_images.shape[0]

        image_height, image_width = torched_images.shape[2:]

        # pixel-scale factors: five (w, h) pairs for the landmarks, two for the box corners
        scale1 = torch.from_numpy(np.tile([image_width, image_height],
                                          5)).to(self.device)
        scale = torch.from_numpy(np.tile([image_width, image_height],
                                         2)).to(self.device)

        # anchor boxes (priors) for this input resolution
        priors = object_from_dict(self.hparams["prior_box"],
                                  image_size=(image_height,
                                              image_width)).to(loc.device)

        for batch_id in range(batch_size):
            labels = []  # reset per image so each JSON file holds only this image's detections
            image_path = image_paths[batch_id]
            file_id = Path(str(image_path)).stem
            raw_image = raw_images[batch_id]

            resize = resizes[batch_id].float()

            # decode box regressions against the priors into pixel-space corners
            boxes = decode(loc.data[batch_id], priors,
                           self.hparams["test_parameters"]["variance"])

            boxes *= scale / resize
            scores = conf[batch_id][:, 1]

            # decode the five facial landmarks the same way
            landmarks = decode_landm(land.data[batch_id], priors,
                                     self.hparams["test_parameters"]["variance"])
            landmarks *= scale1 / resize

            # ignore low scores
            valid_index = torch.where(
                scores > self.hparams["confidence_threshold"])[0]
            boxes = boxes[valid_index]
            landmarks = landmarks[valid_index]
            scores = scores[valid_index]

            order = scores.argsort(descending=True)  # highest confidence first, as greedy NMS expects

            boxes = boxes[order]
            landmarks = landmarks[order]
            scores = scores[order]

            # do NMS
            keep = nms(boxes, scores, self.hparams["nms_threshold"])
            boxes = boxes[keep, :].int()

            if boxes.shape[0] == 0:
                continue

            landmarks = landmarks[keep].int()
            scores = scores[keep].cpu().numpy().astype(np.float64)

            boxes = boxes[:self.hparams["keep_top_k"]]
            landmarks = landmarks[:self.hparams["keep_top_k"]]
            scores = scores[:self.hparams["keep_top_k"]]

            if self.hparams["visualize"]:
                vis_image = raw_image.cpu().numpy().copy()

                for crop_id, bbox in enumerate(boxes):
                    landms = landmarks[crop_id].cpu().numpy().reshape([5, 2])

                    colors = [(255, 0, 0), (128, 255, 0), (255, 178, 102),
                              (102, 128, 255), (0, 255, 255)]
                    for i, (x, y) in enumerate(landms):
                        vis_image = cv2.circle(vis_image, (x, y),
                                               radius=3,
                                               color=colors[i],
                                               thickness=3)

                    x_min, y_min, x_max, y_max = bbox.cpu().numpy()

                    x_min = np.clip(x_min, 0, x_max - 1)
                    y_min = np.clip(y_min, 0, y_max - 1)

                    vis_image = cv2.rectangle(vis_image, (x_min, y_min),
                                              (x_max, y_max),
                                              color=(0, 255, 0),
                                              thickness=2)

                # write one visualization per image, after all boxes are drawn
                cv2.imwrite(str(self.output_vis_path / f"{file_id}.jpg"),
                            cv2.cvtColor(vis_image, cv2.COLOR_BGR2RGB))

            for crop_id, bbox in enumerate(boxes):
                bbox = bbox.cpu().numpy()

                labels += [{
                    "crop_id": crop_id,
                    "bbox": bbox.tolist(),
                    "score": scores[crop_id],
                    "landmarks": landmarks[crop_id].tolist(),
                }]

            result = {
                "file_path": str(image_path),
                "file_id": file_id,
                "bboxes": labels,
            }

            with open(self.output_label_path / f"{file_id}.json", "w") as f:
                json.dump(result, f, indent=2)
Example #2
def main():
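    """Distributed (one process per GPU) inference entry point: builds the model,
    loads weights, shards the input files across ranks, and runs predict()."""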
    args = get_args()
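    # one process per GPU; launch via torchrun / torch.distributed.launch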
    torch.distributed.init_process_group(backend="nccl")

    with open(args.config_path) as f:
        hparams = yaml.load(f, Loader=yaml.SafeLoader)

    hparams.update({
        "json_path": args.output_path,
        "visualize": args.visualize,
        "confidence_threshold": args.confidence_threshold,
        "nms_threshold": args.nms_threshold,
        "keep_top_k": args.keep_top_k,
        "local_rank": args.local_rank,
        "prior_box": object_from_dict(hparams["prior_box"],
                                      image_size=[args.max_size, args.max_size]),
        "fp16": args.fp16,
        "folder_in_name": args.folder_in_name,
    })

    if args.visualize:
        output_vis_path = args.output_path / "viz"
        output_vis_path.mkdir(parents=True, exist_ok=True)
        hparams["output_vis_path"] = output_vis_path

    output_label_path = args.output_path / "labels"
    output_label_path.mkdir(parents=True, exist_ok=True)
    hparams["output_label_path"] = output_label_path

    device = torch.device("cuda", args.local_rank)

    model = object_from_dict(hparams["model"])
    model = model.to(device)

    if args.fp16:
        model = model.half()

    # strip the Lightning "model." prefix so checkpoint keys match the bare model
    corrections: Dict[str, str] = {"model.": ""}
    checkpoint = load_checkpoint(file_path=args.weight_path,
                                 rename_in_layers=corrections)
    model.load_state_dict(checkpoint["state_dict"])

    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank)

    file_paths = []

    for regexp in ["*.jpg", "*.png", "*.jpeg", "*.JPG"]:
        file_paths += sorted(tqdm(args.input_path.rglob(regexp)))

    dataset = InferenceDataset(file_paths,
                               max_size=args.max_size,
                               transform=from_dict(hparams["test_aug"]))

    sampler = DistributedSampler(dataset, shuffle=False)  # each rank gets a disjoint, ordered shard

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
        sampler=sampler,
    )

    predict(dataloader, model, hparams, device)
Example #3
        "file_paths":
        sorted([x for x in args.input_path.rglob("*") if x.is_file()]),
        "json_path":
        args.output_path,
        "output_path":
        args.output_path,
        "visualize":
        args.visualize,
        "origin_size":
        args.origin_size,
        "max_size":
        args.max_size,
        "target_size":
        args.target_size,
        "confidence_threshold":
        args.confidence_threshold,
        "nms_threshold":
        args.nms_threshold,
        "keep_top_k":
        args.keep_top_k,
    })
    hparams["trainer"]["gpus"] = 1  # Right now we work only with one GPU

    model = InferenceModel(hparams, weight_path=args.weight_path)
    trainer = object_from_dict(
        hparams["trainer"],
        checkpoint_callback=object_from_dict(hparams["checkpoint_callback"]),
    )

    trainer.test(model)
Example #4
def main():
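    """Single-process inference: decode detections per image and optionally save
    crops, visualizations, and per-image JSON labels."""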
    args = get_args()
    torch.set_grad_enabled(False)

    with open(args.config_path) as f:
        hparams = yaml.load(f, Loader=yaml.SafeLoader)

    device = torch.device("cpu" if args.cpu else "cuda")

    net = get_model(hparams, args.weights, args.fp16, device)

    file_paths = sorted(args.input_path.rglob("*.jpg"))

    output_path = args.output_path
    output_vis_path = output_path / "viz"
    output_label_path = output_path / "labels"
    output_image_path = output_path / "images"

    prepare_output_folders(output_vis_path, output_label_path,
                           output_image_path, args.save_boxes, args.save_crops,
                           args.visualize)

    transform = from_dict(hparams["test_aug"])

    test_loader = DataLoader(
        InferenceDataset(file_paths, args.origin_size, transform=transform),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    with torch.no_grad():
        for raw_input in tqdm(test_loader):
            torched_images = raw_input["torched_image"].type(net.dtype)  # match the model dtype (fp32 or fp16)

            resizes = raw_input["resize"]
            image_paths = raw_input["image_path"]
            raw_images = raw_input["raw_image"]

            labels = []

            # resume support: with batch_size 1, skip images whose labels were already written
            if (args.batch_size == 1 and args.save_boxes
                    and (output_label_path /
                         f"{Path(image_paths[0]).stem}.json").exists()):
                continue

            loc, conf, land = net(torched_images.to(device))  # forward pass

            batch_size = torched_images.shape[0]

            image_height, image_width = torched_images.shape[2:]

            # pixel-scale factors: five (w, h) pairs for the landmarks, two for the box corners
            scale1 = torch.tensor([image_width, image_height] * 5,
                                  dtype=torch.float32,
                                  device=device)
            scale = torch.tensor([image_width, image_height] * 2,
                                 dtype=torch.float32,
                                 device=device)

            # anchor boxes (priors) for this input resolution
            priors = object_from_dict(hparams["prior_box"],
                                      image_size=(image_height,
                                                  image_width)).to(loc.device)

            for batch_id in range(batch_size):
                labels = []  # reset per image so each JSON file holds only this image's detections
                image_path = image_paths[batch_id]
                file_id = Path(image_path).stem
                raw_image = raw_images[batch_id]

                resize = resizes[batch_id].float()

                # decode box regressions against the priors into pixel-space corners
                boxes = decode(loc.data[batch_id], priors,
                               hparams["test_parameters"]["variance"])

                boxes *= scale / resize
                scores = conf[batch_id][:, 1]

                landmarks = decode_landm(
                    land.data[batch_id], priors,
                    hparams["test_parameters"]["variance"])
                landmarks *= scale1 / resize

                # ignore low scores
                valid_index = torch.where(
                    scores > args.confidence_threshold)[0]
                boxes = boxes[valid_index]
                landmarks = landmarks[valid_index]
                scores = scores[valid_index]

                order = scores.argsort(descending=True)

                boxes = boxes[order]
                landmarks = landmarks[order]
                scores = scores[order]

                # do NMS
                keep = nms(boxes, scores, args.nms_threshold)
                boxes = boxes[keep, :].int()

                landmarks = landmarks[keep].int()

                if boxes.shape[0] == 0:
                    continue

                scores = scores[keep].cpu().numpy().astype(np.float64)

                if args.visualize:
                    vis_image = raw_image.cpu().numpy().copy()

                for crop_id, bbox in enumerate(boxes):
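                    # one label record per detection; optionally save the face crop too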
                    bbox = bbox.cpu().numpy()

                    labels += [{
                        "crop_id": crop_id,
                        "bbox": bbox.tolist(),
                        "score": scores[crop_id],
                        "landmarks": landmarks[crop_id].tolist(),
                    }]

                    if args.save_crops:
                        x_min, y_min, x_max, y_max = bbox

                        # clamp to the image and guarantee a non-empty crop
                        x_min = np.clip(x_min, 0, x_max - 1)
                        y_min = np.clip(y_min, 0, y_max - 1)

                        crop = raw_image[y_min:y_max,
                                         x_min:x_max].cpu().numpy()

                        if args.visualize:
                            vis_image = cv2.rectangle(vis_image,
                                                      (x_min, y_min),
                                                      (x_max, y_max),
                                                      color=(255, 0, 0),
                                                      thickness=3)

                        target_folder = output_image_path / f"{file_id}"
                        target_folder.mkdir(exist_ok=True, parents=True)

                        crop_file_path = target_folder / f"{file_id}_{crop_id}.jpg"

                        if crop_file_path.exists():
                            continue

                        cv2.imwrite(str(crop_file_path),
                                    cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))

                if args.visualize:
                    cv2.imwrite(str(output_vis_path / f"{file_id}.jpg"),
                                cv2.cvtColor(vis_image, cv2.COLOR_BGR2RGB))

                if args.save_boxes:
                    result = {
                        "file_path": image_path,
                        "file_id": file_id,
                        "bboxes": labels,
                    }

                    with open(output_label_path / f"{file_id}.json", "w") as f:
                        json.dump(result, f, indent=2)