def test_dataset() -> None:
    """Save ten overlay images of one ``CharDataset`` sample.

    Rows come from ``/store/points.json``. They are wrapped in a
    ``CharDataset`` that uses ``test_transforms``. Sample index 2 is then drawn,
    with its points in red, to ``/store/test-0.jpg`` through
    ``/store/test-9.jpg``.

    NOTE(review): the same index (2) is fetched on every iteration. If
    ``test_transforms`` is deterministic, the ten images are identical. Confirm
    whether this is meant to visualise transform randomness or whether
    ``dataset[i]`` was intended.
    """
    dataset = CharDataset(
        rows=read_train_rows("/store/points.json"),
        transforms=test_transforms,
    )
    for run in range(10):
        # The sample id and labels are not needed for the visual check.
        _, image, points, _ = dataset[2]
        plot = DetectionPlot(inv_normalize(image))
        plot.draw_points(points, color="red", size=0.5)
        plot.save(f"/store/test-{run}.jpg")
def eval_step() -> None:
    """Run one evaluation pass over ``test_loader`` and maybe checkpoint.

    For every batch, this function:

    * moves images, ground-truth points and labels to ``device``;
    * builds ground-truth heatmaps at the model's output resolution with
      ``config.mkmaps`` and computes ``config.hmloss`` against the output;
    * decodes predicted points with ``config.to_points`` at image resolution;
    * saves one overlay per sample to
      ``{config.out_dir}/{id}-points-.png``, with predictions in blue and
      ground truth in red.

    The mean loss over all batches is written to ``logs["test_loss"]`` and
    passed to ``model_loader.save_if_needed``.

    Relies on module-level ``model``, ``test_loader``, ``device``, ``config``,
    ``logs`` and ``model_loader``.
    """
    import torch  # local import so this block does not depend on file imports

    model.eval()
    loss_meter = MeanMeter()
    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        for ids, image_batch, gt_point_batch, gt_label_batch in tqdm(test_loader):
            image_batch = image_batch.to(device)
            gt_point_batch = [x.to(device) for x in gt_point_batch]
            gt_label_batch = [x.to(device) for x in gt_label_batch]
            _, _, h, w = image_batch.shape
            netout = model(image_batch)
            # Ground-truth heatmaps must match the network output size, not the image size.
            _, _, hm_h, hm_w = netout.shape
            gt_hms = config.mkmaps(gt_point_batch, gt_label_batch, w=hm_w, h=hm_h)
            loss = config.hmloss(netout, gt_hms)
            # Only the decoded points are plotted; confidences and labels are unused here.
            point_batch, _, _ = config.to_points(netout, h=h, w=w)
            loss_meter.update(loss.item())

            for sample_id, image, points, gt_points in zip(
                ids, image_batch, point_batch, gt_point_batch
            ):
                plot = DetectionPlot(inv_normalize(image))
                plot.draw_points(points, color="blue", size=w / 100)
                plot.draw_points(gt_points, color="red", size=w / 150)
                plot.save(f"{config.out_dir}/{sample_id}-points-.png")

    mean_loss = loss_meter.get_value()
    logs["test_loss"] = mean_loss
    # Checkpoint on the epoch-mean loss. The previous code passed the loss of
    # the last batch only, and also failed with an empty loader.
    model_loader.save_if_needed(model, mean_loss)