def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
    torched_images = batch["torched_image"]
    resizes = batch["resize"]
    image_paths = batch["image_path"]
    raw_images = batch["raw_image"]

    loc, conf, land = self.model(torched_images)
    conf = F.softmax(conf, dim=-1)

    batch_size = torched_images.shape[0]
    image_height, image_width = torched_images.shape[2:]

    # Scale factors to map relative coordinates back to pixels:
    # 5 (x, y) pairs for landmarks, 2 pairs for box corners.
    scale1 = torch.from_numpy(np.tile([image_width, image_height], 5)).to(self.device)
    scale = torch.from_numpy(np.tile([image_width, image_height], 2)).to(self.device)

    priors = object_from_dict(self.hparams["prior_box"], image_size=(image_height, image_width)).to(loc.device)

    for batch_id in range(batch_size):
        image_path = image_paths[batch_id]
        file_id = Path(str(image_path)).stem
        raw_image = raw_images[batch_id]
        resize = resizes[batch_id].float()

        boxes = decode(loc.data[batch_id], priors, self.hparams["test_parameters"]["variance"])
        boxes *= scale / resize
        scores = conf[batch_id][:, 1]

        landmarks = decode_landm(land.data[batch_id], priors, self.hparams["test_parameters"]["variance"])
        landmarks *= scale1 / resize

        # Ignore low scores.
        valid_index = torch.where(scores > self.hparams["confidence_threshold"])[0]
        boxes = boxes[valid_index]
        landmarks = landmarks[valid_index]
        scores = scores[valid_index]

        # Sort by score before NMS.
        order = scores.argsort(descending=True)
        boxes = boxes[order]
        landmarks = landmarks[order]
        scores = scores[order]

        # Do NMS.
        keep = nms(boxes, scores, self.hparams["nms_threshold"])
        boxes = boxes[keep, :].int()

        if boxes.shape[0] == 0:
            continue

        landmarks = landmarks[keep].int()
        scores = scores[keep].cpu().numpy().astype(np.float64)

        boxes = boxes[: self.hparams["keep_top_k"]]
        landmarks = landmarks[: self.hparams["keep_top_k"]]
        scores = scores[: self.hparams["keep_top_k"]]

        if self.hparams["visualize"]:
            vis_image = raw_image.cpu().numpy().copy()

            for crop_id, bbox in enumerate(boxes):
                landms = landmarks[crop_id].cpu().numpy().reshape([5, 2])
                colors = [(255, 0, 0), (128, 255, 0), (255, 178, 102), (102, 128, 255), (0, 255, 255)]

                for i, (x, y) in enumerate(landms):
                    vis_image = cv2.circle(vis_image, (x, y), radius=3, color=colors[i], thickness=3)

                x_min, y_min, x_max, y_max = bbox.cpu().numpy()
                x_min = np.clip(x_min, 0, x_max - 1)
                y_min = np.clip(y_min, 0, y_max - 1)

                vis_image = cv2.rectangle(
                    vis_image, (x_min, y_min), (x_max, y_max), color=(0, 255, 0), thickness=2
                )

            cv2.imwrite(
                str(self.output_vis_path / f"{file_id}.jpg"),
                cv2.cvtColor(vis_image, cv2.COLOR_BGR2RGB),
            )

        # Reset per image so detections do not leak into the next file's JSON.
        labels: List[Dict[str, Any]] = []

        for crop_id, bbox in enumerate(boxes):
            bbox = bbox.cpu().numpy()

            labels += [
                {
                    "crop_id": crop_id,
                    "bbox": bbox.tolist(),
                    "score": scores[crop_id],
                    "landmarks": landmarks[crop_id].tolist(),
                }
            ]

        result = {"file_path": image_path, "file_id": file_id, "bboxes": labels}

        with open(self.output_label_path / f"{file_id}.json", "w") as f:
            json.dump(result, f, indent=2)
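# For reference, a minimal sketch of the SSD-style box decoding that `decode`
# above is assumed to perform (the canonical RetinaFace formulation; the
# repo's actual implementation may differ in details):
# import torch; from typing import List
def decode_sketch(loc: "torch.Tensor", priors: "torch.Tensor", variances: "List[float]") -> "torch.Tensor":
    # priors are (cx, cy, w, h); loc holds predicted offsets. Centers are
    # shifted by a scaled offset, sizes are scaled exponentially.
    boxes = torch.cat(
        (
            priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
            priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]),
        ),
        dim=1,
    )
    # Convert (cx, cy, w, h) to (x_min, y_min, x_max, y_max).
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes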
def main():
    args = get_args()
    torch.distributed.init_process_group(backend="nccl")

    with open(args.config_path) as f:
        hparams = yaml.load(f, Loader=yaml.SafeLoader)

    hparams.update(
        {
            "json_path": args.output_path,
            "visualize": args.visualize,
            "confidence_threshold": args.confidence_threshold,
            "nms_threshold": args.nms_threshold,
            "keep_top_k": args.keep_top_k,
            "local_rank": args.local_rank,
            "prior_box": object_from_dict(hparams["prior_box"], image_size=[args.max_size, args.max_size]),
            "fp16": args.fp16,
            "folder_in_name": args.folder_in_name,
        }
    )

    if args.visualize:
        output_vis_path = args.output_path / "viz"
        output_vis_path.mkdir(parents=True, exist_ok=True)
        hparams["output_vis_path"] = output_vis_path

    output_label_path = args.output_path / "labels"
    output_label_path.mkdir(parents=True, exist_ok=True)
    hparams["output_label_path"] = output_label_path

    device = torch.device("cuda", args.local_rank)

    model = object_from_dict(hparams["model"])
    model = model.to(device)

    if args.fp16:
        model = model.half()

    # Strip the "model." prefix that Lightning adds to checkpoint keys.
    corrections: Dict[str, str] = {"model.": ""}
    checkpoint = load_checkpoint(file_path=args.weight_path, rename_in_layers=corrections)
    model.load_state_dict(checkpoint["state_dict"])

    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank
    )

    file_paths = []
    for regexp in ["*.jpg", "*.png", "*.jpeg", "*.JPG"]:
        file_paths += sorted(tqdm(args.input_path.rglob(regexp)))

    dataset = InferenceDataset(file_paths, max_size=args.max_size, transform=from_dict(hparams["test_aug"]))

    sampler = DistributedSampler(dataset, shuffle=False)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
        sampler=sampler,
    )

    predict(dataloader, model, hparams, device)
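# Sketch of what load_checkpoint with rename_in_layers is assumed to do above:
# load the checkpoint and rewrite state_dict keys (here, dropping the "model."
# prefix). This is a hypothetical reimplementation for illustration only; the
# repo's actual helper may differ.
# import re; import torch; from typing import Any, Dict
def load_checkpoint_sketch(file_path, rename_in_layers: "Dict[str, str]") -> "Dict[str, Any]":
    checkpoint = torch.load(file_path, map_location="cpu")
    state_dict = {}
    for key, value in checkpoint["state_dict"].items():
        for pattern, replacement in rename_in_layers.items():
            key = re.sub(pattern, replacement, key)
        state_dict[key] = value
    checkpoint["state_dict"] = state_dict
    return checkpoint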
"file_paths": sorted([x for x in args.input_path.rglob("*") if x.is_file()]), "json_path": args.output_path, "output_path": args.output_path, "visualize": args.visualize, "origin_size": args.origin_size, "max_size": args.max_size, "target_size": args.target_size, "confidence_threshold": args.confidence_threshold, "nms_threshold": args.nms_threshold, "keep_top_k": args.keep_top_k, }) hparams["trainer"]["gpus"] = 1 # Right now we work only with one GPU model = InferenceModel(hparams, weight_path=args.weight_path) trainer = object_from_dict( hparams["trainer"], checkpoint_callback=object_from_dict(hparams["checkpoint_callback"]), ) trainer.test(model)
def main():
    args = get_args()
    torch.set_grad_enabled(False)

    with open(args.config_path) as f:
        hparams = yaml.load(f, Loader=yaml.SafeLoader)

    device = torch.device("cpu" if args.cpu else "cuda")

    net = get_model(hparams, args.weights, args.fp16, device)

    file_paths = sorted(args.input_path.rglob("*.jpg"))

    output_path = args.output_path
    output_vis_path = output_path / "viz"
    output_label_path = output_path / "labels"
    output_image_path = output_path / "images"

    prepare_output_folders(
        output_vis_path, output_label_path, output_image_path, args.save_boxes, args.save_crops, args.visualize
    )

    transform = from_dict(hparams["test_aug"])

    test_loader = DataLoader(
        InferenceDataset(file_paths, args.origin_size, transform=transform),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    with torch.no_grad():
        for raw_input in tqdm(test_loader):
            torched_images = raw_input["torched_image"].type(net.dtype)
            resizes = raw_input["resize"]
            image_paths = raw_input["image_path"]
            raw_images = raw_input["raw_image"]

            # Skip images whose labels were already written in a previous run.
            if (
                args.batch_size == 1
                and args.save_boxes
                and (output_label_path / f"{Path(image_paths[0]).stem}.json").exists()
            ):
                continue

            loc, conf, land = net(torched_images.to(device))  # forward pass

            batch_size = torched_images.shape[0]
            image_height, image_width = torched_images.shape[2:]

            # Scale factors to map relative coordinates back to pixels:
            # 5 (x, y) pairs for landmarks, 2 pairs for box corners.
            scale1 = torch.tensor([image_width, image_height] * 5, dtype=torch.float32, device=device)
            scale = torch.tensor([image_width, image_height] * 2, dtype=torch.float32, device=device)

            priors = object_from_dict(hparams["prior_box"], image_size=(image_height, image_width)).to(loc.device)

            for batch_id in range(batch_size):
                image_path = image_paths[batch_id]
                file_id = Path(image_path).stem
                raw_image = raw_images[batch_id]
                resize = resizes[batch_id].float()

                boxes = decode(loc.data[batch_id], priors, hparams["test_parameters"]["variance"])
                boxes *= scale / resize
                scores = conf[batch_id][:, 1]

                landmarks = decode_landm(land.data[batch_id], priors, hparams["test_parameters"]["variance"])
                landmarks *= scale1 / resize

                # Ignore low scores.
                valid_index = torch.where(scores > args.confidence_threshold)[0]
                boxes = boxes[valid_index]
                landmarks = landmarks[valid_index]
                scores = scores[valid_index]

                # Sort by score before NMS.
                order = scores.argsort(descending=True)
                boxes = boxes[order]
                landmarks = landmarks[order]
                scores = scores[order]

                # Do NMS.
                keep = nms(boxes, scores, args.nms_threshold)
                boxes = boxes[keep, :].int()
                landmarks = landmarks[keep].int()

                if boxes.shape[0] == 0:
                    continue

                scores = scores[keep].cpu().numpy().astype(np.float64)

                if args.visualize:
                    vis_image = raw_image.cpu().numpy().copy()

                # Reset per image so detections do not leak into the next file's JSON.
                labels = []

                for crop_id, bbox in enumerate(boxes):
                    bbox = bbox.cpu().numpy()

                    labels += [
                        {
                            "crop_id": crop_id,
                            "bbox": bbox.tolist(),
                            "score": scores[crop_id],
                            "landmarks": landmarks[crop_id].tolist(),
                        }
                    ]

                    if args.save_crops:
                        x_min, y_min, x_max, y_max = bbox
                        x_min = np.clip(x_min, 0, x_max - 1)
                        y_min = np.clip(y_min, 0, y_max - 1)

                        crop = raw_image[y_min:y_max, x_min:x_max].cpu().numpy()

                        if args.visualize:
                            vis_image = cv2.rectangle(
                                vis_image, (x_min, y_min), (x_max, y_max), color=(255, 0, 0), thickness=3
                            )

                        target_folder = output_image_path / f"{file_id}"
                        target_folder.mkdir(exist_ok=True, parents=True)

                        crop_file_path = target_folder / f"{file_id}_{crop_id}.jpg"

                        if crop_file_path.exists():
                            continue

                        cv2.imwrite(str(crop_file_path), cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))

                if args.visualize:
                    cv2.imwrite(
                        str(output_vis_path / f"{file_id}.jpg"),
                        cv2.cvtColor(vis_image, cv2.COLOR_BGR2RGB),
                    )

                if args.save_boxes:
                    result = {
                        "file_path": image_path,
                        "file_id": file_id,
                        "bboxes": labels,
                    }

                    with open(output_label_path / f"{file_id}.json", "w") as f:
                        json.dump(result, f, indent=2)
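# The label JSONs written above follow the schema {"file_path", "file_id",
# "bboxes": [{"crop_id", "bbox", "score", "landmarks"}, ...]}. A hypothetical
# downstream snippet consuming one of them, drawing the saved boxes back onto
# the source image:
# import json; import cv2; from pathlib import Path
def draw_saved_boxes(label_path: "Path") -> None:
    with open(label_path) as f:
        result = json.load(f)
    image = cv2.imread(result["file_path"])
    for label in result["bboxes"]:
        x_min, y_min, x_max, y_max = label["bbox"]  # integer pixel coordinates
        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color=(0, 255, 0), thickness=2)
    cv2.imwrite(f"{result['file_id']}_boxes.jpg", image)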